Commit d71312d104

Pavel Verigo <paul.verigo@gmail.com>
2024-07-23 17:06:18
stage2-wasm: mul_sat 32 bits <=, i64, i128
1 parent 0c6aa44
Changed files (2)
src
arch
test
src/arch/wasm/CodeGen.zig
@@ -1837,6 +1837,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
         .sub_sat => func.airSatBinOp(inst, .sub),
         .sub_wrap => func.airWrapBinOp(inst, .sub),
         .mul => func.airBinOp(inst, .mul),
+        .mul_sat => func.airSatMul(inst),
         .mul_wrap => func.airWrapBinOp(inst, .mul),
         .div_float, .div_exact => func.airDiv(inst),
         .div_trunc => func.airDivTrunc(inst),
@@ -2002,7 +2003,6 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
         .error_set_has_value => func.airErrorSetHasValue(inst),
         .frame_addr => func.airFrameAddress(inst),
 
-        .mul_sat,
         .assembly,
         .is_err_ptr,
         .is_non_err_ptr,
@@ -6783,6 +6783,106 @@ fn airMod(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
     return func.finishAir(inst, .stack, &.{ bin_op.lhs, bin_op.rhs });
 }
 
+fn airSatMul(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
+    const bin_op = func.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
+
+    const pt = func.pt;
+    const mod = pt.zcu;
+    const ty = func.typeOfIndex(inst);
+    const int_info = ty.intInfo(mod);
+    const is_signed = int_info.signedness == .signed;
+
+    const lhs = try func.resolveInst(bin_op.lhs);
+    const rhs = try func.resolveInst(bin_op.rhs);
+    const wasm_bits = toWasmBits(int_info.bits) orelse {
+        return func.fail("TODO: mul_sat for {}", .{ty.fmt(pt)});
+    };
+
+    switch (wasm_bits) {
+        32 => {
+            const upcast_ty: Type = if (is_signed) Type.i64 else Type.u64;
+            const lhs_up = try func.intcast(lhs, ty, upcast_ty);
+            const rhs_up = try func.intcast(rhs, ty, upcast_ty);
+            var mul_res = try (try func.binOp(lhs_up, rhs_up, upcast_ty, .mul)).toLocal(func, upcast_ty);
+            defer mul_res.free(func);
+            if (is_signed) {
+                const imm_max: WValue = .{ .imm64 = ~@as(u64, 0) >> @intCast(64 - (int_info.bits - 1)) };
+                try func.emitWValue(mul_res);
+                try func.emitWValue(imm_max);
+                _ = try func.cmp(mul_res, imm_max, upcast_ty, .lt);
+                try func.addTag(.select);
+
+                var tmp = try func.allocLocal(upcast_ty);
+                defer tmp.free(func);
+                try func.addLabel(.local_set, tmp.local.value);
+
+                const imm_min: WValue = .{ .imm64 = ~@as(u64, 0) << @intCast(int_info.bits - 1) };
+                try func.emitWValue(tmp);
+                try func.emitWValue(imm_min);
+                _ = try func.cmp(tmp, imm_min, upcast_ty, .gt);
+                try func.addTag(.select);
+            } else {
+                const imm_max: WValue = .{ .imm64 = ~@as(u64, 0) >> @intCast(64 - int_info.bits) };
+                try func.emitWValue(mul_res);
+                try func.emitWValue(imm_max);
+                _ = try func.cmp(mul_res, imm_max, upcast_ty, .lt);
+                try func.addTag(.select);
+            }
+            try func.addTag(.i32_wrap_i64);
+        },
+        64 => {
+            if (!(int_info.bits == 64 and int_info.signedness == .signed)) {
+                return func.fail("TODO: mul_sat for {}", .{ty.fmt(pt)});
+            }
+            const overflow_ret = try func.allocStack(Type.i32);
+            _ = try func.callIntrinsic(
+                "__mulodi4",
+                &[_]InternPool.Index{ .i64_type, .i64_type, .usize_type },
+                Type.i64,
+                &.{ lhs, rhs, overflow_ret },
+            );
+            const xor = try func.binOp(lhs, rhs, Type.i64, .xor);
+            const sign_v = try func.binOp(xor, .{ .imm64 = 63 }, Type.i64, .shr);
+            _ = try func.binOp(sign_v, .{ .imm64 = ~@as(u63, 0) }, Type.i64, .xor);
+            _ = try func.load(overflow_ret, Type.i32, 0);
+            try func.addTag(.i32_eqz);
+            try func.addTag(.select);
+        },
+        128 => {
+            if (!(int_info.bits == 128 and int_info.signedness == .signed)) {
+                return func.fail("TODO: mul_sat for {}", .{ty.fmt(pt)});
+            }
+            const overflow_ret = try func.allocStack(Type.i32);
+            const ret = try func.callIntrinsic(
+                "__muloti4",
+                &[_]InternPool.Index{ .i128_type, .i128_type, .usize_type },
+                Type.i128,
+                &.{ lhs, rhs, overflow_ret },
+            );
+            try func.lowerToStack(ret);
+            const xor = try func.binOp(lhs, rhs, Type.i128, .xor);
+            const sign_v = try func.binOp(xor, .{ .imm32 = 127 }, Type.i128, .shr);
+
+            // xor ~@as(u127, 0)
+            try func.emitWValue(sign_v);
+            const lsb = try func.load(sign_v, Type.u64, 0);
+            _ = try func.binOp(lsb, .{ .imm64 = ~@as(u64, 0) }, Type.u64, .xor);
+            try func.store(.stack, .stack, Type.u64, sign_v.offset());
+            try func.emitWValue(sign_v);
+            const msb = try func.load(sign_v, Type.u64, 8);
+            _ = try func.binOp(msb, .{ .imm64 = ~@as(u63, 0) }, Type.u64, .xor);
+            try func.store(.stack, .stack, Type.u64, sign_v.offset() + 8);
+
+            try func.lowerToStack(sign_v);
+            _ = try func.load(overflow_ret, Type.i32, 0);
+            try func.addTag(.i32_eqz);
+            try func.addTag(.select);
+        },
+        else => unreachable,
+    }
+    return func.finishAir(inst, .stack, &.{ bin_op.lhs, bin_op.rhs });
+}
+
 fn airSatBinOp(func: *CodeGen, inst: Air.Inst.Index, op: Op) InnerError!void {
     assert(op == .add or op == .sub);
     const bin_op = func.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
test/behavior/saturating_arithmetic.zig
@@ -154,6 +154,109 @@ test "saturating subtraction 128bit" {
     try comptime S.doTheTest();
 }
 
+fn testSatMul(comptime T: type, a: T, b: T, expected: T) !void {
+    const res: T = a *| b;
+    try expect(res == expected);
+}
+
+test "saturating multiplication <= 32 bits" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .wasm32) {
+        // https://github.com/ziglang/zig/issues/9660
+        return error.SkipZigTest;
+    }
+
+    try testSatMul(u8, 0, maxInt(u8), 0);
+    try testSatMul(u8, 1 << 7, 1 << 7, maxInt(u8));
+    try testSatMul(u8, maxInt(u8) - 1, 2, maxInt(u8));
+    try testSatMul(u8, 1 << 4, 1 << 4, maxInt(u8));
+    try testSatMul(u8, 1 << 4, 1 << 3, 1 << 7);
+    try testSatMul(u8, 1 << 5, 1 << 3, maxInt(u8));
+    try testSatMul(u8, 10, 20, 200);
+
+    try testSatMul(u16, 0, maxInt(u16), 0);
+    try testSatMul(u16, 1 << 15, 1 << 15, maxInt(u16));
+    try testSatMul(u16, maxInt(u16) - 1, 2, maxInt(u16));
+    try testSatMul(u16, 1 << 8, 1 << 8, maxInt(u16));
+    try testSatMul(u16, 1 << 12, 1 << 3, 1 << 15);
+    try testSatMul(u16, 1 << 13, 1 << 3, maxInt(u16));
+    try testSatMul(u16, 10, 20, 200);
+
+    try testSatMul(u32, 0, maxInt(u32), 0);
+    try testSatMul(u32, 1 << 31, 1 << 31, maxInt(u32));
+    try testSatMul(u32, maxInt(u32) - 1, 2, maxInt(u32));
+    try testSatMul(u32, 1 << 16, 1 << 16, maxInt(u32));
+    try testSatMul(u32, 1 << 28, 1 << 3, 1 << 31);
+    try testSatMul(u32, 1 << 29, 1 << 3, maxInt(u32));
+    try testSatMul(u32, 10, 20, 200);
+
+    try testSatMul(i8, 0, maxInt(i8), 0);
+    try testSatMul(i8, 0, minInt(i8), 0);
+    try testSatMul(i8, 1 << 6, 1 << 6, maxInt(i8));
+    try testSatMul(i8, minInt(i8), minInt(i8), maxInt(i8));
+    try testSatMul(i8, maxInt(i8) - 1, 2, maxInt(i8));
+    try testSatMul(i8, minInt(i8) + 1, 2, minInt(i8));
+    try testSatMul(i8, 1 << 4, 1 << 4, maxInt(i8));
+    try testSatMul(i8, minInt(i4), 1 << 4, minInt(i8));
+    try testSatMul(i8, 10, 12, 120);
+    try testSatMul(i8, 10, -12, -120);
+
+    try testSatMul(i16, 0, maxInt(i16), 0);
+    try testSatMul(i16, 0, minInt(i16), 0);
+    try testSatMul(i16, 1 << 14, 1 << 14, maxInt(i16));
+    try testSatMul(i16, minInt(i16), minInt(i16), maxInt(i16));
+    try testSatMul(i16, maxInt(i16) - 1, 2, maxInt(i16));
+    try testSatMul(i16, minInt(i16) + 1, 2, minInt(i16));
+    try testSatMul(i16, 1 << 8, 1 << 8, maxInt(i16));
+    try testSatMul(i16, minInt(i8), 1 << 8, minInt(i16));
+    try testSatMul(i16, 10, 12, 120);
+    try testSatMul(i16, 10, -12, -120);
+
+    try testSatMul(i32, 0, maxInt(i32), 0);
+    try testSatMul(i32, 0, minInt(i32), 0);
+    try testSatMul(i32, 1 << 30, 1 << 30, maxInt(i32));
+    try testSatMul(i32, minInt(i32), minInt(i32), maxInt(i32));
+    try testSatMul(i32, maxInt(i32) - 1, 2, maxInt(i32));
+    try testSatMul(i32, minInt(i32) + 1, 2, minInt(i32));
+    try testSatMul(i32, 1 << 16, 1 << 16, maxInt(i32));
+    try testSatMul(i32, minInt(i16), 1 << 16, minInt(i32));
+    try testSatMul(i32, 10, 12, 120);
+    try testSatMul(i32, 10, -12, -120);
+}
+
+// TODO: remove this test, integrate into general test
+test "saturating mul i64, i128, wasm only" {
+    if (builtin.zig_backend != .stage2_wasm) return error.SkipZigTest;
+
+    try testSatMul(i64, 0, maxInt(i64), 0);
+    try testSatMul(i64, 0, minInt(i64), 0);
+    try testSatMul(i64, 1 << 62, 1 << 62, maxInt(i64));
+    try testSatMul(i64, minInt(i64), minInt(i64), maxInt(i64));
+    try testSatMul(i64, maxInt(i64) - 1, 2, maxInt(i64));
+    try testSatMul(i64, minInt(i64) + 1, 2, minInt(i64));
+    try testSatMul(i64, 1 << 32, 1 << 32, maxInt(i64));
+    try testSatMul(i64, minInt(i32), 1 << 32, minInt(i64));
+    try testSatMul(i64, 10, 12, 120);
+    try testSatMul(i64, 10, -12, -120);
+
+    try testSatMul(i128, 0, maxInt(i128), 0);
+    try testSatMul(i128, 0, minInt(i128), 0);
+    try testSatMul(i128, 1 << 126, 1 << 126, maxInt(i128));
+    try testSatMul(i128, minInt(i128), minInt(i128), maxInt(i128));
+    try testSatMul(i128, maxInt(i128) - 1, 2, maxInt(i128));
+    try testSatMul(i128, minInt(i128) + 1, 2, minInt(i128));
+    try testSatMul(i128, 1 << 64, 1 << 64, maxInt(i128));
+    try testSatMul(i128, minInt(i64), 1 << 64, minInt(i128));
+    try testSatMul(i128, 10, 12, 120);
+    try testSatMul(i128, 10, -12, -120);
+}
+
 test "saturating multiplication" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
@@ -183,23 +286,15 @@ test "saturating multiplication" {
             try testSatMul(u8, 2, 255, 255);
             try testSatMul(u128, maxInt(u128), maxInt(u128), maxInt(u128));
         }
-
-        fn testSatMul(comptime T: type, lhs: T, rhs: T, expected: T) !void {
-            try expect((lhs *| rhs) == expected);
-
-            var x = lhs;
-            x *|= rhs;
-            try expect(x == expected);
-        }
     };
 
     try S.doTheTest();
     try comptime S.doTheTest();
 
-    try comptime S.testSatMul(comptime_int, 0, 0, 0);
-    try comptime S.testSatMul(comptime_int, 3, 2, 6);
-    try comptime S.testSatMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935);
-    try comptime S.testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556);
+    try comptime testSatMul(comptime_int, 0, 0, 0);
+    try comptime testSatMul(comptime_int, 3, 2, 6);
+    try comptime testSatMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935);
+    try comptime testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556);
 }
 
 test "saturating shift-left" {