Commit d5991ee7ca

Xavier Bouchoux <xavierb@gmail.com>
2023-10-14 10:15:11
codegen/wasm: fix non-byte-sized signed integer comparison
1 parent 27a1990
Changed files (2)
src
arch
test
behavior
src/arch/wasm/CodeGen.zig
@@ -3602,11 +3602,6 @@ fn cmp(func: *CodeGen, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareO
         return func.cmpBigInt(lhs, rhs, ty, op);
     }
 
-    // ensure that when we compare pointers, we emit
-    // the true pointer of a stack value, rather than the stack pointer.
-    try func.lowerToStack(lhs);
-    try func.lowerToStack(rhs);
-
     const signedness: std.builtin.Signedness = blk: {
         // by default we tell the operand type is unsigned (i.e. bools and enum values)
         if (ty.zigTypeTag(mod) != .Int) break :blk .unsigned;
@@ -3614,6 +3609,30 @@ fn cmp(func: *CodeGen, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareO
         // incase of an actual integer, we emit the correct signedness
         break :blk ty.intInfo(mod).signedness;
     };
+    const extend_sign = blk: {
+        // do we need to extend the sign bit?
+        if (signedness != .signed) break :blk false;
+        if (op == .eq or op == .neq) break :blk false;
+        const int_bits = ty.intInfo(mod).bits;
+        const wasm_bits = toWasmBits(int_bits) orelse unreachable;
+        break :blk (wasm_bits != int_bits);
+    };
+
+    const lhs_wasm = if (extend_sign)
+        try func.signExtendInt(lhs, ty)
+    else
+        lhs;
+
+    const rhs_wasm = if (extend_sign)
+        try func.signExtendInt(rhs, ty)
+    else
+        rhs;
+
+    // ensure that when we compare pointers, we emit
+    // the true pointer of a stack value, rather than the stack pointer.
+    try func.lowerToStack(lhs_wasm);
+    try func.lowerToStack(rhs_wasm);
+
     const opcode: wasm.Opcode = buildOpcode(.{
         .valtype1 = typeToValtype(ty, mod),
         .op = switch (op) {
@@ -6920,12 +6939,13 @@ fn signedSat(func: *CodeGen, lhs_operand: WValue, rhs_operand: WValue, ty: Type,
     const int_info = ty.intInfo(mod);
     const wasm_bits = toWasmBits(int_info.bits).?;
     const is_wasm_bits = wasm_bits == int_info.bits;
+    const ext_ty = if (!is_wasm_bits) try mod.intType(int_info.signedness, wasm_bits) else ty;
 
     var lhs = if (!is_wasm_bits) lhs: {
-        break :lhs try (try func.signExtendInt(lhs_operand, ty)).toLocal(func, ty);
+        break :lhs try (try func.signExtendInt(lhs_operand, ty)).toLocal(func, ext_ty);
     } else lhs_operand;
     var rhs = if (!is_wasm_bits) rhs: {
-        break :rhs try (try func.signExtendInt(rhs_operand, ty)).toLocal(func, ty);
+        break :rhs try (try func.signExtendInt(rhs_operand, ty)).toLocal(func, ext_ty);
     } else rhs_operand;
 
     const max_val: u64 = @as(u64, @intCast((@as(u65, 1) << @as(u7, @intCast(int_info.bits - 1))) - 1));
@@ -6941,20 +6961,20 @@ fn signedSat(func: *CodeGen, lhs_operand: WValue, rhs_operand: WValue, ty: Type,
         else => unreachable,
     };
 
-    var bin_result = try (try func.binOp(lhs, rhs, ty, op)).toLocal(func, ty);
+    var bin_result = try (try func.binOp(lhs, rhs, ext_ty, op)).toLocal(func, ext_ty);
     if (!is_wasm_bits) {
         defer bin_result.free(func); // not returned in this branch
         defer lhs.free(func); // uses temporary local for absvalue
         defer rhs.free(func); // uses temporary local for absvalue
         try func.emitWValue(bin_result);
         try func.emitWValue(max_wvalue);
-        _ = try func.cmp(bin_result, max_wvalue, ty, .lt);
+        _ = try func.cmp(bin_result, max_wvalue, ext_ty, .lt);
         try func.addTag(.select);
         try func.addLabel(.local_set, bin_result.local.value); // re-use local
 
         try func.emitWValue(bin_result);
         try func.emitWValue(min_wvalue);
-        _ = try func.cmp(bin_result, min_wvalue, ty, .gt);
+        _ = try func.cmp(bin_result, min_wvalue, ext_ty, .gt);
         try func.addTag(.select);
         try func.addLabel(.local_set, bin_result.local.value); // re-use local
         return (try func.wrapOperand(bin_result, ty)).toLocal(func, ty);
@@ -7036,12 +7056,13 @@ fn airShlSat(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
             64 => WValue{ .imm64 = shift_size },
             else => unreachable,
         };
+        const ext_ty = try mod.intType(int_info.signedness, wasm_bits);
 
-        var shl_res = try (try func.binOp(lhs, shift_value, ty, .shl)).toLocal(func, ty);
+        var shl_res = try (try func.binOp(lhs, shift_value, ext_ty, .shl)).toLocal(func, ext_ty);
         defer shl_res.free(func);
-        var shl = try (try func.binOp(shl_res, rhs, ty, .shl)).toLocal(func, ty);
+        var shl = try (try func.binOp(shl_res, rhs, ext_ty, .shl)).toLocal(func, ext_ty);
         defer shl.free(func);
-        var shr = try (try func.binOp(shl, rhs, ty, .shr)).toLocal(func, ty);
+        var shr = try (try func.binOp(shl, rhs, ext_ty, .shr)).toLocal(func, ext_ty);
         defer shr.free(func);
 
         switch (wasm_bits) {
@@ -7053,7 +7074,7 @@ fn airShlSat(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
 
                 try func.addImm32(std.math.minInt(i32));
                 try func.addImm32(std.math.maxInt(i32));
-                _ = try func.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt);
+                _ = try func.cmp(shl_res, .{ .imm32 = 0 }, ext_ty, .lt);
                 try func.addTag(.select);
             },
             64 => blk: {
@@ -7064,16 +7085,16 @@ fn airShlSat(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
 
                 try func.addImm64(@as(u64, @bitCast(@as(i64, std.math.minInt(i64)))));
                 try func.addImm64(@as(u64, @bitCast(@as(i64, std.math.maxInt(i64)))));
-                _ = try func.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt);
+                _ = try func.cmp(shl_res, .{ .imm64 = 0 }, ext_ty, .lt);
                 try func.addTag(.select);
             },
             else => unreachable,
         }
         try func.emitWValue(shl);
-        _ = try func.cmp(shl_res, shr, ty, .neq);
+        _ = try func.cmp(shl_res, shr, ext_ty, .neq);
         try func.addTag(.select);
         try func.addLabel(.local_set, result.local.value);
-        var shift_result = try func.binOp(result, shift_value, ty, .shr);
+        var shift_result = try func.binOp(result, shift_value, ext_ty, .shr);
         if (is_signed) {
             shift_result = try func.wrapOperand(shift_result, ty);
         }
test/behavior/basic.zig
@@ -1172,3 +1172,51 @@ test "pointer to struct literal with runtime field is constant" {
     const ptr = &S{ .data = runtime_zero };
     try expect(@typeInfo(@TypeOf(ptr)).Pointer.is_const);
 }
+
+test "integer compare" {
+    const S = struct {
+        fn doTheTestSigned(comptime T: type) !void {
+            var z: T = 0;
+            var p: T = 123;
+            var n: T = -123;
+            try expect(z == z and z != p and z != n);
+            try expect(p == p and p != n and n == n);
+            try expect(z > n and z < p and z >= n and z <= p);
+            try expect(!(z < n or z > p or z <= n or z >= p or z > z or z < z));
+            try expect(p > n and n < p and p >= n and n <= p and p >= p and p <= p and n >= n and n <= n);
+            try expect(!(p < n or n > p or p <= n or n >= p or p > p or p < p or n > n or n < n));
+            try expect(z == 0 and z != 123 and z != -123 and 0 == z and 0 != p and 0 != n);
+            try expect(z > -123 and p > -123 and !(n > 123));
+            try expect(z < 123 and !(p < 123) and n < 123);
+            try expect(-123 <= z and -123 <= p and -123 <= n);
+            try expect(123 >= z and 123 >= p and 123 >= n);
+            try expect(!(0 != z or 123 != p or -123 != n));
+            try expect(!(z > 0 or -123 > p or 123 < n));
+        }
+        fn doTheTestUnsigned(comptime T: type) !void {
+            var z: T = 0;
+            var p: T = 123;
+            try expect(z == z and z != p);
+            try expect(p == p);
+            try expect(z < p and z <= p);
+            try expect(!(z > p or z >= p or z > z or z < z));
+            try expect(p >= p and p <= p);
+            try expect(!(p > p or p < p));
+            try expect(z == 0 and z != 123 and z != -123 and 0 == z and 0 != p);
+            try expect(z > -123 and p > -123);
+            try expect(z < 123 and !(p < 123));
+            try expect(-123 <= z and -123 <= p);
+            try expect(123 >= z and 123 >= p);
+            try expect(!(0 != z or 123 != p));
+            try expect(!(z > 0 or -123 > p));
+        }
+    };
+    inline for (.{ u8, u16, u32, u64, usize, u10, u20, u30, u60 }) |T| {
+        try S.doTheTestUnsigned(T);
+        try comptime S.doTheTestUnsigned(T);
+    }
+    inline for (.{ i8, i16, i32, i64, isize, i10, i20, i30, i60 }) |T| {
+        try S.doTheTestSigned(T);
+        try comptime S.doTheTestSigned(T);
+    }
+}