Commit 56d535dd24

Pavel Verigo <paul.verigo@gmail.com>
2024-07-18 18:01:06
stage2-wasm: improve @shlWithOverflow for <= 128 bits
Additionally fixed a bug for shr on signed big ints
1 parent dc3176d
Changed files (2)
src
arch
test
behavior
src/arch/wasm/CodeGen.zig
@@ -2669,8 +2669,14 @@ fn binOpBigInt(func: *CodeGen, lhs: WValue, rhs: WValue, ty: Type, op: Op) Inner
             .signed => return func.callIntrinsic("__divti3", &.{ ty.toIntern(), ty.toIntern() }, ty, &.{ lhs, rhs }),
             .unsigned => return func.callIntrinsic("__udivti3", &.{ ty.toIntern(), ty.toIntern() }, ty, &.{ lhs, rhs }),
         },
-        .rem => return func.callIntrinsic("__umodti3", &.{ ty.toIntern(), ty.toIntern() }, ty, &.{ lhs, rhs }),
-        .shr => return func.callIntrinsic("__lshrti3", &.{ ty.toIntern(), .i32_type }, ty, &.{ lhs, rhs }),
+        .rem => switch (int_info.signedness) {
+            .signed => return func.callIntrinsic("__modti3", &.{ ty.toIntern(), ty.toIntern() }, ty, &.{ lhs, rhs }),
+            .unsigned => return func.callIntrinsic("__umodti3", &.{ ty.toIntern(), ty.toIntern() }, ty, &.{ lhs, rhs }),
+        },
+        .shr => switch (int_info.signedness) {
+            .signed => return func.callIntrinsic("__ashrti3", &.{ ty.toIntern(), .i32_type }, ty, &.{ lhs, rhs }),
+            .unsigned => return func.callIntrinsic("__lshrti3", &.{ ty.toIntern(), .i32_type }, ty, &.{ lhs, rhs }),
+        },
         .shl => return func.callIntrinsic("__ashlti3", &.{ ty.toIntern(), .i32_type }, ty, &.{ lhs, rhs }),
         .@"and", .@"or", .xor => {
             const result = try func.allocStack(ty);
@@ -6055,14 +6061,14 @@ fn airShlWithOverflow(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
 
     const lhs = try func.resolveInst(extra.lhs);
     const rhs = try func.resolveInst(extra.rhs);
-    const lhs_ty = func.typeOf(extra.lhs);
+    const ty = func.typeOf(extra.lhs);
     const rhs_ty = func.typeOf(extra.rhs);
 
-    if (lhs_ty.zigTypeTag(mod) == .Vector) {
+    if (ty.zigTypeTag(mod) == .Vector) {
         return func.fail("TODO: Implement overflow arithmetic for vectors", .{});
     }
 
-    const int_info = lhs_ty.intInfo(mod);
+    const int_info = ty.intInfo(mod);
     const wasm_bits = toWasmBits(int_info.bits) orelse {
         return func.fail("TODO: Implement shl_with_overflow for integer bitsize: {d}", .{int_info.bits});
     };
@@ -6070,32 +6076,28 @@ fn airShlWithOverflow(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
     // Ensure rhs is coerced to lhs as they must have the same WebAssembly types
     // before we can perform any binary operation.
     const rhs_wasm_bits = toWasmBits(rhs_ty.intInfo(mod).bits).?;
-    const rhs_final = if (wasm_bits != rhs_wasm_bits) blk: {
-        const rhs_casted = try func.intcast(rhs, rhs_ty, lhs_ty);
-        break :blk try rhs_casted.toLocal(func, lhs_ty);
+    // If wasm_bits == 128, compiler-rt expects i32 for shift
+    const rhs_final = if (wasm_bits != rhs_wasm_bits and wasm_bits == 64) blk: {
+        const rhs_casted = try func.intcast(rhs, rhs_ty, ty);
+        break :blk try rhs_casted.toLocal(func, ty);
     } else rhs;
 
-    var shl = try (try func.binOp(lhs, rhs_final, lhs_ty, .shl)).toLocal(func, lhs_ty);
+    var shl = try (try func.wrapBinOp(lhs, rhs_final, ty, .shl)).toLocal(func, ty);
     defer shl.free(func);
-    var result = if (wasm_bits != int_info.bits) blk: {
-        break :blk try (try func.wrapOperand(shl, lhs_ty)).toLocal(func, lhs_ty);
-    } else shl;
-    defer result.free(func); // it's a no-op to free the same local twice (when wasm_bits == int_info.bits)
 
     const overflow_bit = blk: {
-        try func.emitWValue(lhs);
-        const shr = try func.binOp(result, rhs_final, lhs_ty, .shr);
-        break :blk try func.cmp(.stack, shr, lhs_ty, .neq);
+        const shr = try func.binOp(shl, rhs_final, ty, .shr);
+        break :blk try func.cmp(shr, lhs, ty, .neq);
     };
     var overflow_local = try overflow_bit.toLocal(func, Type.u1);
     defer overflow_local.free(func);
 
-    const result_ptr = try func.allocStack(func.typeOfIndex(inst));
-    try func.store(result_ptr, result, lhs_ty, 0);
-    const offset = @as(u32, @intCast(lhs_ty.abiSize(pt)));
-    try func.store(result_ptr, overflow_local, Type.u1, offset);
+    const result = try func.allocStack(func.typeOfIndex(inst));
+    const offset: u32 = @intCast(ty.abiSize(pt));
+    try func.store(result, shl, ty, 0);
+    try func.store(result, overflow_local, Type.u1, offset);
 
-    return func.finishAir(inst, result_ptr, &.{ extra.lhs, extra.rhs });
+    return func.finishAir(inst, result, &.{ extra.lhs, extra.rhs });
 }
 
 fn airMulWithOverflow(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
test/behavior/math.zig
@@ -1345,62 +1345,56 @@ test "@subWithOverflow > 64 bits" {
     try testSubWithOverflow(i128, maxInt(i128), -2, minInt(i128) + 1, 1);
 }
 
+fn testShlWithOverflow(comptime T: type, a: T, b: math.Log2Int(T), shl: T, bit: u1) !void {
+    const ov = @shlWithOverflow(a, b);
+    try expect(ov[0] == shl);
+    try expect(ov[1] == bit);
+}
+
 test "@shlWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
-    {
-        var a: u4 = 2;
-        _ = &a;
-        var b: u2 = 1;
-        var ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == 4);
-        try expect(ov[1] == 0);
+    try testShlWithOverflow(u4, 2, 1, 4, 0);
+    try testShlWithOverflow(u4, 2, 3, 0, 1);
 
-        b = 3;
-        ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == 0);
-        try expect(ov[1] == 1);
-    }
+    try testShlWithOverflow(i9, 127, 1, 254, 0);
+    try testShlWithOverflow(i9, 127, 2, -4, 1);
 
-    {
-        var a: i9 = 127;
-        _ = &a;
-        var b: u4 = 1;
-        var ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == 254);
-        try expect(ov[1] == 0);
+    try testShlWithOverflow(u16, 0b0010111111111111, 3, 0b0111111111111000, 1);
+    try testShlWithOverflow(u16, 0b0010111111111111, 2, 0b1011111111111100, 0);
 
-        b = 2;
-        ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == -4);
-        try expect(ov[1] == 1);
-    }
+    try testShlWithOverflow(u16, 0b0000_0000_0000_0011, 15, 0b1000_0000_0000_0000, 1);
+    try testShlWithOverflow(u16, 0b0000_0000_0000_0011, 14, 0b1100_0000_0000_0000, 0);
+}
 
-    {
-        const ov = @shlWithOverflow(@as(u16, 0b0010111111111111), 3);
-        try expect(ov[0] == 0b0111111111111000);
-        try expect(ov[1] == 1);
-    }
-    {
-        const ov = @shlWithOverflow(@as(u16, 0b0010111111111111), 2);
-        try expect(ov[0] == 0b1011111111111100);
-        try expect(ov[1] == 0);
-    }
-    {
-        var a: u16 = 0b0000_0000_0000_0011;
-        _ = &a;
-        var b: u4 = 15;
-        var ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == 0b1000_0000_0000_0000);
-        try expect(ov[1] == 1);
-        b = 14;
-        ov = @shlWithOverflow(a, b);
-        try expect(ov[0] == 0b1100_0000_0000_0000);
-        try expect(ov[1] == 0);
-    }
+test "@shlWithOverflow > 64 bits" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    try testShlWithOverflow(u65, 0x0_0100_0000_0000_0000, 7, 0x0_8000_0000_0000_0000, 0);
+    try testShlWithOverflow(u65, 0x0_0100_0000_0000_0000, 8, 0x1_0000_0000_0000_0000, 0);
+    try testShlWithOverflow(u65, 0x0_0100_0000_0000_0000, 9, 0, 1);
+    try testShlWithOverflow(u65, 0x0_0100_0000_0000_0000, 10, 0, 1);
+
+    try testShlWithOverflow(u128, 0x0100_0000_0000_0000_0000000000000000, 6, 0x4000_0000_0000_0000_0000000000000000, 0);
+    try testShlWithOverflow(u128, 0x0100_0000_0000_0000_0000000000000000, 7, 0x8000_0000_0000_0000_0000000000000000, 0);
+    try testShlWithOverflow(u128, 0x0100_0000_0000_0000_0000000000000000, 8, 0, 1);
+    try testShlWithOverflow(u128, 0x0100_0000_0000_0000_0000000000000000, 9, 0, 1);
+
+    try testShlWithOverflow(i65, 0x0_0100_0000_0000_0000, 7, 0x0_8000_0000_0000_0000, 0);
+    try testShlWithOverflow(i65, 0x0_0100_0000_0000_0000, 8, minInt(i65), 1);
+    try testShlWithOverflow(i65, 0x0_0100_0000_0000_0000, 9, 0, 1);
+    try testShlWithOverflow(i65, 0x0_0100_0000_0000_0000, 10, 0, 1);
+
+    try testShlWithOverflow(i128, 0x0100_0000_0000_0000_0000000000000000, 6, 0x4000_0000_0000_0000_0000000000000000, 0);
+    try testShlWithOverflow(i128, 0x0100_0000_0000_0000_0000000000000000, 7, minInt(i128), 1);
+    try testShlWithOverflow(i128, 0x0100_0000_0000_0000_0000000000000000, 8, 0, 1);
+    try testShlWithOverflow(i128, 0x0100_0000_0000_0000_0000000000000000, 9, 0, 1);
 }
 
 test "overflow arithmetic with u0 values" {