Commit ce5d934f5f

Luuk de Gram <luuk@degram.dev>
2022-06-15 22:03:18
wasm: saturating add and sub for signed integers
1 parent fcd4280
Changed files (1)
src
arch
src/arch/wasm/CodeGen.zig
@@ -4885,8 +4885,8 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
 
     const bin_op = self.air.instructions.items(.data)[inst].bin_op;
     const ty = self.air.typeOfIndex(inst);
-    const lhs_operand = try self.resolveInst(bin_op.lhs);
-    const rhs_operand = try self.resolveInst(bin_op.rhs);
+    const lhs = try self.resolveInst(bin_op.lhs);
+    const rhs = try self.resolveInst(bin_op.rhs);
 
     const int_info = ty.intInfo(self.target);
     const is_signed = int_info.signedness == .signed;
@@ -4895,22 +4895,12 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
         return self.fail("TODO: saturating arithmetic for integers with bitsize '{d}'", .{int_info.bits});
     }
 
-    const wasm_bits = toWasmBits(int_info.bits).?;
-
-    const lhs = if (is_signed) blk: {
-        break :blk try self.signAbsValue(lhs_operand, ty);
-    } else lhs_operand;
-    const rhs = if (is_signed) blk: {
-        break :blk try self.signAbsValue(rhs_operand, ty);
-    } else rhs_operand;
-
-    const opcode = buildOpcode(.{ .op = op, .valtype1 = typeToValtype(ty, self.target) });
-    try self.emitWValue(lhs);
-    try self.emitWValue(rhs);
-    try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
-    const bin_result = try self.allocLocal(ty);
-    try self.addLabel(.local_set, bin_result.local);
+    if (is_signed) {
+        return signedSat(self, lhs, rhs, ty, op);
+    }
 
+    const wasm_bits = toWasmBits(int_info.bits).?;
+    const bin_result = try self.binOp(lhs, rhs, ty, op);
     if (wasm_bits != int_info.bits and op == .add) {
         const val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits)) - 1);
         const imm_val = switch (wasm_bits) {
@@ -4919,7 +4909,7 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
             else => unreachable,
         };
 
-        const cmp_result = try self.cmp(bin_result, imm_val, ty, if (op == .add) .lt else .gt);
+        const cmp_result = try self.cmp(bin_result, imm_val, ty, .lt);
         try self.emitWValue(bin_result);
         try self.emitWValue(imm_val);
         try self.emitWValue(cmp_result);
@@ -4939,3 +4929,62 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
     try self.addLabel(.local_set, result.local);
     return result;
 }
+
+fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op: Op) InnerError!WValue {
+    const int_info = ty.intInfo(self.target);
+    const wasm_bits = toWasmBits(int_info.bits).?;
+    const is_wasm_bits = wasm_bits == int_info.bits;
+
+    const lhs = if (!is_wasm_bits) try self.signAbsValue(lhs_operand, ty) else lhs_operand;
+    const rhs = if (!is_wasm_bits) try self.signAbsValue(rhs_operand, ty) else rhs_operand;
+
+    const max_val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits - 1)) - 1);
+    const min_val = @intCast(i64, ~@intCast(u63, max_val));
+    const max_wvalue = switch (wasm_bits) {
+        32 => WValue{ .imm32 = @intCast(u32, max_val) },
+        64 => WValue{ .imm64 = max_val },
+        else => unreachable,
+    };
+    const min_wvalue = switch (wasm_bits) {
+        32 => WValue{ .imm32 = @bitCast(u32, @truncate(i32, min_val)) },
+        64 => WValue{ .imm64 = @bitCast(u64, min_val) },
+        else => unreachable,
+    };
+
+    const bin_result = try self.binOp(lhs, rhs, ty, op);
+    if (!is_wasm_bits) {
+        const cmp_result_lt = try self.cmp(bin_result, max_wvalue, ty, .lt);
+        try self.emitWValue(bin_result);
+        try self.emitWValue(max_wvalue);
+        try self.emitWValue(cmp_result_lt);
+        try self.addTag(.select);
+        try self.addLabel(.local_set, bin_result.local); // re-use local
+
+        const cmp_result_gt = try self.cmp(bin_result, min_wvalue, ty, .gt);
+        try self.emitWValue(bin_result);
+        try self.emitWValue(min_wvalue);
+        try self.emitWValue(cmp_result_gt);
+        try self.addTag(.select);
+        try self.addLabel(.local_set, bin_result.local); // re-use local
+        return self.wrapOperand(bin_result, ty);
+    } else {
+        const zero = switch (wasm_bits) {
+            32 => WValue{ .imm32 = 0 },
+            64 => WValue{ .imm64 = 0 },
+            else => unreachable,
+        };
+        const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt);
+        const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt);
+        const xor = try self.binOp(cmp_zero_result, cmp_bin_result, ty, .xor);
+        const cmp_bin_zero_result = try self.cmp(bin_result, zero, ty, .lt);
+        try self.emitWValue(max_wvalue);
+        try self.emitWValue(min_wvalue);
+        try self.emitWValue(cmp_bin_zero_result);
+        try self.addTag(.select);
+        try self.emitWValue(bin_result);
+        try self.emitWValue(xor);
+        try self.addTag(.select);
+        try self.addLabel(.local_set, bin_result.local); // re-use local
+        return bin_result;
+    }
+}