Commit 305b113a53

Luuk de Gram <luuk@degram.dev>
2022-07-25 07:10:56
wasm: keep result of `cmp` on the stack
By keeping the result on the stack, we prevent codegen from generating unnecessary locals when the result is consumed directly by subsequent instructions and does not have to be re-used.
1 parent cc6f2b6
Changed files (1)
src
arch
src/arch/wasm/CodeGen.zig
@@ -2659,10 +2659,14 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) Inner
     const lhs = try self.resolveInst(bin_op.lhs);
     const rhs = try self.resolveInst(bin_op.rhs);
     const operand_ty = self.air.typeOf(bin_op.lhs);
-    return self.cmp(lhs, rhs, operand_ty, op);
+    return (try self.cmp(lhs, rhs, operand_ty, op)).toLocal(self, Type.u32); // comparison result is always 32 bits
 }
 
+/// Compares two operands.
+/// Asserts rhs is not a stack value when the lhs isn't a stack value either
+/// NOTE: This leaves the result on top of the stack, rather than a new local.
 fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOperator) InnerError!WValue {
+    assert(!(lhs != .stack and rhs == .stack));
     if (ty.zigTypeTag() == .Optional and !ty.optionalReprIsPayload()) {
         var buf: Type.Payload.ElemType = undefined;
         const payload_ty = ty.optionalChild(&buf);
@@ -2704,9 +2708,7 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
     });
     try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
 
-    const cmp_tmp = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
-    try self.addLabel(.local_set, cmp_tmp.local);
-    return cmp_tmp;
+    return WValue{ .stack = {} };
 }
 
 fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
@@ -2729,9 +2731,7 @@ fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperato
     try self.emitWValue(ext_rhs);
     try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
 
-    const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
-    try self.addLabel(.local_set, result.local);
-    return result;
+    return WValue{ .stack = {} };
 }
 
 fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -3982,13 +3982,16 @@ fn cmpOptionals(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std
     try self.addImm32(0);
     try self.addTag(if (op == .eq) .i32_ne else .i32_eq);
     try self.addLabel(.local_set, result.local);
-    return result;
+    try self.emitWValue(result);
+    return WValue{ .stack = {} };
 }
 
 /// Compares big integers by checking both its high bits and low bits.
+/// NOTE: Leaves the result of the comparison on top of the stack.
 /// TODO: Lower this to compiler_rt call when bitsize > 128
 fn cmpBigInt(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std.math.CompareOperator) InnerError!WValue {
     assert(operand_ty.abiSize(self.target) >= 16);
+    assert(!(lhs != .stack and rhs == .stack));
     if (operand_ty.intInfo(self.target).bits > 128) {
         return self.fail("TODO: Support cmpBigInt for integer bitsize: '{d}'", .{operand_ty.intInfo(self.target).bits});
     }
@@ -4012,20 +4015,15 @@ fn cmpBigInt(self: *Self, lhs: WValue, rhs: WValue, operand_ty: Type, op: std.ma
         },
         else => {
             const ty = if (operand_ty.isSignedInt()) Type.i64 else Type.u64;
-            const high_bit_eql = try self.cmp(lhs_high_bit, rhs_high_bit, ty, .eq);
-            const high_bit_cmp = try self.cmp(lhs_high_bit, rhs_high_bit, ty, op);
-            const low_bit_cmp = try self.cmp(lhs_low_bit, rhs_low_bit, ty, op);
-
-            try self.emitWValue(low_bit_cmp);
-            try self.emitWValue(high_bit_cmp);
-            try self.emitWValue(high_bit_eql);
+            // leave these values on top of the stack for '.select'
+            _ = try self.cmp(lhs_low_bit, rhs_low_bit, ty, op);
+            _ = try self.cmp(lhs_high_bit, rhs_high_bit, ty, op);
+            _ = try self.cmp(lhs_high_bit, rhs_high_bit, ty, .eq);
             try self.addTag(.select);
         },
     }
 
-    const result = try self.allocLocal(Type.initTag(.i32));
-    try self.addLabel(.local_set, result.local);
-    return result;
+    return WValue{ .stack = {} };
 }
 
 fn airSetUnionTag(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
@@ -4350,7 +4348,7 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
         if (wasm_bits == int_info.bits) {
             const cmp_zero = try self.cmp(rhs, zero, lhs_ty, cmp_op);
             const lt = try self.cmp(bin_op, lhs, lhs_ty, .lt);
-            break :blk try (try self.binOp(cmp_zero, lt, Type.u32, .xor)).toLocal(self, Type.u32); // result of cmp_zero and lt is always 32bit
+            break :blk try self.binOp(cmp_zero, lt, Type.u32, .xor);
         }
         const abs = try self.signAbsValue(bin_op, lhs_ty);
         break :blk try self.cmp(abs, bin_op, lhs_ty, .neq);
@@ -4358,11 +4356,12 @@ fn airAddSubWithOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!W
         try self.cmp(bin_op, lhs, lhs_ty, cmp_op)
     else
         try self.cmp(bin_op, result, lhs_ty, .neq);
+    const overflow_local = try overflow_bit.toLocal(self, Type.u32);
 
     const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
     try self.store(result_ptr, result, lhs_ty, 0);
     const offset = @intCast(u32, lhs_ty.abiSize(self.target));
-    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
+    try self.store(result_ptr, overflow_local, Type.initTag(.u1), offset);
 
     return result_ptr;
 }
@@ -4384,9 +4383,9 @@ fn airAddSubWithOverflowBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type,
     const high_op_res = try (try self.binOp(lhs_high_bit, rhs_high_bit, Type.u64, op)).toLocal(self, Type.u64);
 
     const lt = if (op == .add) blk: {
-        break :blk try self.cmp(high_op_res, lhs_high_bit, Type.u64, .lt);
+        break :blk try (try self.cmp(high_op_res, lhs_high_bit, Type.u64, .lt)).toLocal(self, Type.u32);
     } else if (op == .sub) blk: {
-        break :blk try self.cmp(lhs_high_bit, rhs_high_bit, Type.u64, .lt);
+        break :blk try (try self.cmp(lhs_high_bit, rhs_high_bit, Type.u64, .lt)).toLocal(self, Type.u32);
     } else unreachable;
     const tmp = try self.intcast(lt, Type.u32, Type.u64);
     const tmp_op = try (try self.binOp(low_op_res, tmp, Type.u64, op)).toLocal(self, Type.u64);
@@ -4400,27 +4399,23 @@ fn airAddSubWithOverflowBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type,
         const wrap = try self.binOp(to_wrap, xor_op, Type.u64, .@"and");
         break :blk try self.cmp(wrap, .{ .imm64 = 0 }, Type.i64, .lt); // i64 because signed
     } else blk: {
-        const eq = try self.cmp(tmp_op, lhs_low_bit, Type.u64, .eq);
-        const op_eq = try self.cmp(tmp_op, lhs_low_bit, Type.u64, if (op == .add) .lt else .gt);
-
         const first_arg = if (op == .sub) arg: {
             break :arg try self.cmp(high_op_res, lhs_high_bit, Type.u64, .gt);
         } else lt;
 
         try self.emitWValue(first_arg);
-        try self.emitWValue(op_eq);
-        try self.emitWValue(eq);
+        _ = try self.cmp(tmp_op, lhs_low_bit, Type.u64, if (op == .add) .lt else .gt);
+        _ = try self.cmp(tmp_op, lhs_low_bit, Type.u64, .eq);
         try self.addTag(.select);
 
-        const overflow_bit = try self.allocLocal(Type.initTag(.u1));
-        try self.addLabel(.local_set, overflow_bit.local);
-        break :blk overflow_bit;
+        break :blk WValue{ .stack = {} };
     };
+    const overflow_local = try overflow_bit.toLocal(self, Type.initTag(.u1));
 
     const result_ptr = try self.allocStack(result_ty);
     try self.store(result_ptr, high_op_res, Type.u64, 0);
     try self.store(result_ptr, tmp_op, Type.u64, 8);
-    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), 16);
+    try self.store(result_ptr, overflow_local, Type.initTag(.u1), 16);
 
     return result_ptr;
 }
@@ -4455,11 +4450,12 @@ fn airShlWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         const shr = try (try self.binOp(result, rhs, lhs_ty, .shr)).toLocal(self, lhs_ty);
         break :blk try self.cmp(lhs, shr, lhs_ty, .neq);
     };
+    const overflow_local = try overflow_bit.toLocal(self, Type.initTag(.u1));
 
     const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
     try self.store(result_ptr, result, lhs_ty, 0);
     const offset = @intCast(u32, lhs_ty.abiSize(self.target));
-    try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
+    try self.store(result_ptr, overflow_local, Type.initTag(.u1), offset);
 
     return result_ptr;
 }
@@ -4502,8 +4498,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         if (int_info.signedness == .unsigned) {
             const shr = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
             const wrap = try self.intcast(shr, new_ty, lhs_ty);
-            const cmp_res = try self.cmp(wrap, zero, lhs_ty, .neq);
-            try self.emitWValue(cmp_res);
+            _ = try self.cmp(wrap, zero, lhs_ty, .neq);
             try self.addLabel(.local_set, overflow_bit.local);
             break :blk try self.intcast(bin_op, new_ty, lhs_ty);
         } else {
@@ -4512,8 +4507,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 
             const shr_res = try self.binOp(bin_op, .{ .imm64 = int_info.bits }, new_ty, .shr);
             const down_shr_res = try self.intcast(shr_res, new_ty, lhs_ty);
-            const cmp_res = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
-            try self.emitWValue(cmp_res);
+            _ = try self.cmp(down_shr_res, shr, lhs_ty, .neq);
             try self.addLabel(.local_set, overflow_bit.local);
             break :blk down_cast;
         }
@@ -4522,8 +4516,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         const rhs_abs = try self.signAbsValue(rhs, lhs_ty);
         const bin_op = try (try self.binOp(lhs_abs, rhs_abs, lhs_ty, .mul)).toLocal(self, lhs_ty);
         const mul_abs = try self.signAbsValue(bin_op, lhs_ty);
-        const cmp_op = try self.cmp(mul_abs, bin_op, lhs_ty, .neq);
-        try self.emitWValue(cmp_op);
+        _ = try self.cmp(mul_abs, bin_op, lhs_ty, .neq);
         try self.addLabel(.local_set, overflow_bit.local);
         break :blk try self.wrapOperand(bin_op, lhs_ty);
     } else blk: {
@@ -4533,8 +4526,7 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         else
             WValue{ .imm64 = int_info.bits };
         const shr = try self.binOp(bin_op, shift_imm, lhs_ty, .shr);
-        const cmp_op = try self.cmp(shr, zero, lhs_ty, .neq);
-        try self.emitWValue(cmp_op);
+        _ = try self.cmp(shr, zero, lhs_ty, .neq);
         try self.addLabel(.local_set, overflow_bit.local);
         break :blk try self.wrapOperand(bin_op, lhs_ty);
     };
@@ -4562,12 +4554,10 @@ fn airMaxMin(self: *Self, inst: Air.Inst.Index, op: enum { max, min }) InnerErro
     const lhs = try self.resolveInst(bin_op.lhs);
     const rhs = try self.resolveInst(bin_op.rhs);
 
-    const cmp_result = try self.cmp(lhs, rhs, ty, if (op == .max) .gt else .lt);
-
     // operands to select from
     try self.lowerToStack(lhs);
     try self.lowerToStack(rhs);
-    try self.emitWValue(cmp_result);
+    _ = try self.cmp(lhs, rhs, ty, if (op == .max) .gt else .lt);
 
     // based on the result from comparison, return operand 0 or 1.
     try self.addTag(.select);
@@ -4638,7 +4628,6 @@ fn airClz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         128 => {
             const msb = try self.load(operand, Type.u64, 0);
             const lsb = try self.load(operand, Type.u64, 8);
-            const neq = try self.cmp(lsb, .{ .imm64 = 0 }, Type.u64, .neq);
 
             try self.emitWValue(lsb);
             try self.addTag(.i64_clz);
@@ -4646,7 +4635,7 @@ fn airClz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
             try self.addTag(.i64_clz);
             try self.emitWValue(.{ .imm64 = 64 });
             try self.addTag(.i64_add);
-            try self.emitWValue(neq);
+            _ = try self.cmp(lsb, .{ .imm64 = 0 }, Type.u64, .neq);
             try self.addTag(.select);
             try self.addTag(.i32_wrap_i64);
         },
@@ -4700,7 +4689,6 @@ fn airCtz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         128 => {
             const msb = try self.load(operand, Type.u64, 0);
             const lsb = try self.load(operand, Type.u64, 8);
-            const neq = try self.cmp(msb, .{ .imm64 = 0 }, Type.u64, .neq);
 
             try self.emitWValue(msb);
             try self.addTag(.i64_ctz);
@@ -4716,7 +4704,7 @@ fn airCtz(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
             } else {
                 try self.addTag(.i64_add);
             }
-            try self.emitWValue(neq);
+            _ = try self.cmp(msb, .{ .imm64 = 0 }, Type.u64, .neq);
             try self.addTag(.select);
             try self.addTag(.i32_wrap_i64);
         },
@@ -4952,15 +4940,13 @@ fn airDivFloor(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
             64 => WValue{ .imm64 = 0 },
             else => unreachable,
         };
-        const lhs_less_than_zero = try self.cmp(lhs_res, zero, ty, .lt);
-        const rhs_less_than_zero = try self.cmp(rhs_res, zero, ty, .lt);
 
         const div_result = try self.allocLocal(ty);
         // leave on stack
         _ = try self.binOp(lhs_res, rhs_res, ty, .div);
         try self.addLabel(.local_tee, div_result.local);
-        try self.emitWValue(lhs_less_than_zero);
-        try self.emitWValue(rhs_less_than_zero);
+        _ = try self.cmp(lhs_res, zero, ty, .lt);
+        _ = try self.cmp(rhs_res, zero, ty, .lt);
         switch (wasm_bits) {
             32 => {
                 try self.addTag(.i32_xor);
@@ -5140,19 +5126,17 @@ fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
             else => unreachable,
         };
 
-        const cmp_result = try self.cmp(bin_result, imm_val, ty, .lt);
         try self.emitWValue(bin_result);
         try self.emitWValue(imm_val);
-        try self.emitWValue(cmp_result);
+        _ = try self.cmp(bin_result, imm_val, ty, .lt);
     } else {
-        const cmp_result = try self.cmp(bin_result, lhs, ty, if (op == .add) .lt else .gt);
         switch (wasm_bits) {
             32 => try self.addImm32(if (op == .add) @as(i32, -1) else 0),
             64 => try self.addImm64(if (op == .add) @bitCast(u64, @as(i64, -1)) else 0),
             else => unreachable,
         }
         try self.emitWValue(bin_result);
-        try self.emitWValue(cmp_result);
+        _ = try self.cmp(bin_result, lhs, ty, if (op == .add) .lt else .gt);
     }
 
     try self.addTag(.select);
@@ -5184,17 +5168,15 @@ fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op
 
     const bin_result = try (try self.binOp(lhs, rhs, ty, op)).toLocal(self, ty);
     if (!is_wasm_bits) {
-        const cmp_result_lt = try self.cmp(bin_result, max_wvalue, ty, .lt);
         try self.emitWValue(bin_result);
         try self.emitWValue(max_wvalue);
-        try self.emitWValue(cmp_result_lt);
+        _ = try self.cmp(bin_result, max_wvalue, ty, .lt);
         try self.addTag(.select);
         try self.addLabel(.local_set, bin_result.local); // re-use local
 
-        const cmp_result_gt = try self.cmp(bin_result, min_wvalue, ty, .gt);
         try self.emitWValue(bin_result);
         try self.emitWValue(min_wvalue);
-        try self.emitWValue(cmp_result_gt);
+        _ = try self.cmp(bin_result, min_wvalue, ty, .gt);
         try self.addTag(.select);
         try self.addLabel(.local_set, bin_result.local); // re-use local
         return self.wrapOperand(bin_result, ty);
@@ -5204,15 +5186,14 @@ fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op
             64 => WValue{ .imm64 = 0 },
             else => unreachable,
         };
-        const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt);
-        const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt);
-        const cmp_bin_zero_result = try self.cmp(bin_result, zero, ty, .lt);
         try self.emitWValue(max_wvalue);
         try self.emitWValue(min_wvalue);
-        try self.emitWValue(cmp_bin_zero_result);
+        _ = try self.cmp(bin_result, zero, ty, .lt);
         try self.addTag(.select);
         try self.emitWValue(bin_result);
         // leave on stack
+        const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt);
+        const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt);
         _ = try self.binOp(cmp_zero_result, cmp_bin_result, Type.u32, .xor); // comparisons always return i32, so provide u32 as type to xor.
         try self.addTag(.select);
         try self.addLabel(.local_set, bin_result.local); // re-use local
@@ -5239,7 +5220,6 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     if (wasm_bits == int_info.bits) {
         const shl = try (try self.binOp(lhs, rhs, ty, .shl)).toLocal(self, ty);
         const shr = try (try self.binOp(shl, rhs, ty, .shr)).toLocal(self, ty);
-        const cmp_result = try self.cmp(lhs, shr, ty, .neq);
 
         switch (wasm_bits) {
             32 => blk: {
@@ -5247,10 +5227,9 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
                     try self.addImm32(-1);
                     break :blk;
                 }
-                const less_than_zero = try self.cmp(lhs, .{ .imm32 = 0 }, ty, .lt);
                 try self.addImm32(std.math.minInt(i32));
                 try self.addImm32(std.math.maxInt(i32));
-                try self.emitWValue(less_than_zero);
+                _ = try self.cmp(lhs, .{ .imm32 = 0 }, ty, .lt);
                 try self.addTag(.select);
             },
             64 => blk: {
@@ -5258,16 +5237,15 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
                     try self.addImm64(@bitCast(u64, @as(i64, -1)));
                     break :blk;
                 }
-                const less_than_zero = try self.cmp(lhs, .{ .imm64 = 0 }, ty, .lt);
                 try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64))));
                 try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64))));
-                try self.emitWValue(less_than_zero);
+                _ = try self.cmp(lhs, .{ .imm64 = 0 }, ty, .lt);
                 try self.addTag(.select);
             },
             else => unreachable,
         }
         try self.emitWValue(shl);
-        try self.emitWValue(cmp_result);
+        _ = try self.cmp(lhs, shr, ty, .neq);
         try self.addTag(.select);
         try self.addLabel(.local_set, result.local);
         return result;
@@ -5282,7 +5260,6 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         const shl_res = try (try self.binOp(lhs, shift_value, ty, .shl)).toLocal(self, ty);
         const shl = try (try self.binOp(shl_res, rhs, ty, .shl)).toLocal(self, ty);
         const shr = try (try self.binOp(shl, rhs, ty, .shr)).toLocal(self, ty);
-        const cmp_result = try self.cmp(shl_res, shr, ty, .neq);
 
         switch (wasm_bits) {
             32 => blk: {
@@ -5291,10 +5268,9 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
                     break :blk;
                 }
 
-                const less_than_zero = try self.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt);
                 try self.addImm32(std.math.minInt(i32));
                 try self.addImm32(std.math.maxInt(i32));
-                try self.emitWValue(less_than_zero);
+                _ = try self.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt);
                 try self.addTag(.select);
             },
             64 => blk: {
@@ -5303,16 +5279,15 @@ fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
                     break :blk;
                 }
 
-                const less_than_zero = try self.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt);
                 try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64))));
                 try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64))));
-                try self.emitWValue(less_than_zero);
+                _ = try self.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt);
                 try self.addTag(.select);
             },
             else => unreachable,
         }
         try self.emitWValue(shl);
-        try self.emitWValue(cmp_result);
+        _ = try self.cmp(shl_res, shr, ty, .neq);
         try self.addTag(.select);
         try self.addLabel(.local_set, result.local);
         const shift_result = try (try self.binOp(result, shift_value, ty, .shr)).toLocal(self, ty);