Commit 160aa4c11d

Luuk de Gram <luuk@degram.dev>
2022-05-13 21:25:23
wasm: Improve shl_with_overflow
This re-implements the shl_with_overflow operation from scratch, making it a lot more robust and outputs the equal code to the LLVM backend.
1 parent 0a2d3d4
Changed files (3)
src
arch
test
src/arch/wasm/CodeGen.zig
@@ -1452,7 +1452,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
 
         .add_with_overflow => self.airAddSubWithOverflow(inst, .add),
         .sub_with_overflow => self.airAddSubWithOverflow(inst, .sub),
-        .shl_with_overflow => self.airBinOpOverflow(inst, .shl),
+        .shl_with_overflow => self.airShlWithOverflow(inst),
         .mul_with_overflow => self.airMulWithOverflow(inst),
 
         .clz => self.airClz(inst),
@@ -3941,115 +3941,6 @@ fn airPtrSliceFieldPtr(self: *Self, inst: Air.Inst.Index, offset: u32) InnerErro
     return self.buildPointerOffset(slice_ptr, offset, .new);
 }
 
-fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
-    if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
-
-    const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
-    const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
-    const lhs = try self.resolveInst(extra.lhs);
-    const rhs = try self.resolveInst(extra.rhs);
-    const lhs_ty = self.air.typeOf(extra.lhs);
-
-    if (lhs_ty.zigTypeTag() == .Vector) {
-        return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
-    }
-
-    // We store the bit if it's overflowed or not in this. As it's zero-initialized
-    // we only need to update it if an overflow (or underflow) occured.
-    const overflow_bit = try self.allocLocal(Type.initTag(.u1));
-    const int_info = lhs_ty.intInfo(self.target);
-    const wasm_bits = toWasmBits(int_info.bits) orelse {
-        return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
-    };
-
-    const zero = switch (wasm_bits) {
-        32 => WValue{ .imm32 = 0 },
-        64 => WValue{ .imm64 = 0 },
-        else => unreachable,
-    };
-    const int_max = (@as(u65, 1) << @intCast(u7, int_info.bits - @boolToInt(int_info.signedness == .signed))) - 1;
-    const int_max_wvalue = switch (wasm_bits) {
-        32 => WValue{ .imm32 = @intCast(u32, int_max) },
-        64 => WValue{ .imm64 = @intCast(u64, int_max) },
-        else => unreachable,
-    };
-    const int_min = if (int_info.signedness == .unsigned)
-        @as(i64, 0)
-    else
-        -@as(i64, 1) << @intCast(u6, int_info.bits - 1);
-    const int_min_wvalue = switch (wasm_bits) {
-        32 => WValue{ .imm32 = @bitCast(u32, @intCast(i32, int_min)) },
-        64 => WValue{ .imm64 = @bitCast(u64, int_min) },
-        else => unreachable,
-    };
-
-    if (int_info.signedness == .unsigned and op == .add) {
-        const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
-        const cmp_res = try self.cmp(rhs, diff, lhs_ty, .gt);
-        try self.emitWValue(cmp_res);
-        try self.addLabel(.local_set, overflow_bit.local);
-    } else if (op == .sub) {
-        const cmp_res = try self.cmp(lhs, rhs, lhs_ty, .lt);
-        try self.emitWValue(cmp_res);
-        try self.addLabel(.local_set, overflow_bit.local);
-    } else if (int_info.signedness == .signed and op != .shl) {
-        // for overflow, we first check if lhs is > 0 (or lhs < 0 in case of subtraction). If not, we will not overflow.
-        // We first create an outer block, where we handle overflow.
-        // Then we create an inner block, where underflow is handled.
-        try self.startBlock(.block, wasm.block_empty);
-        try self.startBlock(.block, wasm.block_empty);
-        {
-            try self.emitWValue(lhs);
-            const cmp_result = try self.cmp(lhs, zero, lhs_ty, .lt);
-            try self.emitWValue(cmp_result);
-        }
-        try self.addLabel(.br_if, 0); // break to outer block, and handle underflow
-
-        // handle overflow
-        {
-            const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
-            const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .gt else .lt);
-            try self.emitWValue(cmp_res);
-            try self.addLabel(.local_set, overflow_bit.local);
-        }
-        try self.addLabel(.br, 1); // break from blocks, and continue regular flow.
-        try self.endBlock();
-
-        // handle underflow
-        {
-            const diff = try self.binOp(int_min_wvalue, lhs, lhs_ty, .sub);
-            const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .lt else .gt);
-            try self.emitWValue(cmp_res);
-            try self.addLabel(.local_set, overflow_bit.local);
-        }
-        try self.endBlock();
-    }
-
-    const bin_op = if (op == .shl) blk: {
-        const tmp_val = try self.binOp(lhs, rhs, lhs_ty, op);
-        const cmp_res = try self.cmp(tmp_val, int_max_wvalue, lhs_ty, .gt);
-        try self.emitWValue(cmp_res);
-        try self.addLabel(.local_set, overflow_bit.local);
-
-        try self.emitWValue(tmp_val);
-        try self.emitWValue(int_max_wvalue);
-        switch (wasm_bits) {
-            32 => try self.addTag(.i32_and),
-            64 => try self.addTag(.i64_and),
-            else => unreachable,
-        }
-        try self.addLabel(.local_set, tmp_val.local);
-        break :blk tmp_val;
-    } else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
-
-    const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
-    try self.store(result_ptr, bin_op, lhs_ty, 0);
-    const offset = @intCast(u32, lhs_ty.abiSize(self.target));
-    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
-
-    return result_ptr;
-}
-
 fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
     assert(op == .add or op == .sub);
     const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
@@ -4065,13 +3956,9 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
     const int_info = lhs_ty.intInfo(self.target);
     const is_signed = int_info.signedness == .signed;
     const wasm_bits = toWasmBits(int_info.bits) orelse {
-        return self.fail("TODO: Implement sub_with_overflow for integer bitsize: {d}", .{int_info.bits});
+        return self.fail("TODO: Implement {{add/sub}}_with_overflow for integer bitsize: {d}", .{int_info.bits});
     };
 
-    if (wasm_bits == 128) {
-        return self.fail("TODO: Implement sub_with_overflow for 128 bit integers", .{});
-    }
-
     const zero = switch (wasm_bits) {
         32 => WValue{ .imm32 = 0 },
         64 => WValue{ .imm64 = 0 },
@@ -4123,6 +4010,53 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
     return result_ptr;
 }
 
+fn airShlWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
+    const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+    const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+    const lhs = try self.resolveInst(extra.lhs);
+    const rhs = try self.resolveInst(extra.rhs);
+    const lhs_ty = self.air.typeOf(extra.lhs);
+
+    if (lhs_ty.zigTypeTag() == .Vector) {
+        return self.fail("TODO: Implement overflow arithmetic for vectors", .{});
+    }
+
+    const int_info = lhs_ty.intInfo(self.target);
+    const is_signed = int_info.signedness == .signed;
+    const wasm_bits = toWasmBits(int_info.bits) orelse {
+        return self.fail("TODO: Implement shl_with_overflow for integer bitsize: {d}", .{int_info.bits});
+    };
+
+    const shl = try self.binOp(lhs, rhs, lhs_ty, .shl);
+    const result = if (wasm_bits != int_info.bits) blk: {
+        break :blk try self.wrapOperand(shl, lhs_ty);
+    } else shl;
+
+    const overflow_bit = if (wasm_bits != int_info.bits and is_signed) blk: {
+        const shift_amt = wasm_bits - int_info.bits;
+        const shift_val = switch (wasm_bits) {
+            32 => WValue{ .imm32 = shift_amt },
+            64 => WValue{ .imm64 = shift_amt },
+            else => unreachable,
+        };
+
+        const secondary_shl = try self.binOp(shl, shift_val, lhs_ty, .shl);
+        const initial_shr = try self.binOp(secondary_shl, shift_val, lhs_ty, .shr);
+        const shr = try self.wrapBinOp(initial_shr, rhs, lhs_ty, .shr);
+        break :blk try self.cmp(lhs, shr, lhs_ty, .neq);
+    } else blk: {
+        const shr = try self.binOp(result, rhs, lhs_ty, .shr);
+        break :blk try self.cmp(lhs, shr, lhs_ty, .neq);
+    };
+
+    const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
+    try self.store(result_ptr, result, lhs_ty, 0);
+    const offset = @intCast(u32, lhs_ty.abiSize(self.target));
+    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
+
+    return result_ptr;
+}
+
 fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
     const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
test/behavior/union.zig
@@ -212,7 +212,6 @@ test "union with specified enum tag" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     try doTest();
     comptime try doTest();
@@ -222,7 +221,6 @@ test "packed union generates correctly aligned type" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage1) return error.SkipZigTest;
 
     const U = packed union {
test/behavior/while.zig
@@ -146,7 +146,6 @@ test "while with optional as condition" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     numbers_left = 10;
     var sum: i32 = 0;
@@ -160,7 +159,6 @@ test "while with optional as condition with else" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     numbers_left = 10;
     var sum: i32 = 0;
@@ -179,7 +177,6 @@ test "while with error union condition" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     numbers_left = 10;
     var sum: i32 = 0;