Commit dc3176d628

Pavel Verigo <paul.verigo@gmail.com>
2024-07-18 17:54:44
stage2-wasm: enhance add/subWithOverflow
Added behavior tests to verify implementation
1 parent d1bd951
Changed files (2)
src
arch
test
behavior
src/arch/wasm/CodeGen.zig
@@ -5980,6 +5980,26 @@ fn airPtrSliceFieldPtr(func: *CodeGen, inst: Air.Inst.Index, offset: u32) InnerE
     return func.finishAir(inst, result, &.{ty_op.operand});
 }
 
+/// NOTE: Allocates place for result on virtual stack, when integer size > 64 bits
+fn intZeroValue(func: *CodeGen, ty: Type) InnerError!WValue {
+    const mod = func.bin_file.base.comp.module.?;
+    const int_info = ty.intInfo(mod);
+    const wasm_bits = toWasmBits(int_info.bits) orelse {
+        return func.fail("TODO: Implement intZeroValue for integer bitsize: {d}", .{int_info.bits});
+    };
+    switch (wasm_bits) {
+        32 => return .{ .imm32 = 0 },
+        64 => return .{ .imm64 = 0 },
+        128 => {
+            const result = try func.allocStack(ty);
+            try func.store(result, .{ .imm64 = 0 }, Type.u64, 0);
+            try func.store(result, .{ .imm64 = 0 }, Type.u64, 8);
+            return result;
+        },
+        else => unreachable,
+    }
+}
+
 fn airAddSubWithOverflow(func: *CodeGen, inst: Air.Inst.Index, op: Op) InnerError!void {
     assert(op == .add or op == .sub);
     const ty_pl = func.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
@@ -5987,124 +6007,44 @@ fn airAddSubWithOverflow(func: *CodeGen, inst: Air.Inst.Index, op: Op) InnerErro
 
     const lhs = try func.resolveInst(extra.lhs);
     const rhs = try func.resolveInst(extra.rhs);
-    const lhs_ty = func.typeOf(extra.lhs);
+    const ty = func.typeOf(extra.lhs);
     const pt = func.pt;
     const mod = pt.zcu;
 
-    if (lhs_ty.zigTypeTag(mod) == .Vector) {
+    if (ty.zigTypeTag(mod) == .Vector) {
         return func.fail("TODO: Implement overflow arithmetic for vectors", .{});
     }
 
-    const int_info = lhs_ty.intInfo(mod);
+    const int_info = ty.intInfo(mod);
     const is_signed = int_info.signedness == .signed;
-    const wasm_bits = toWasmBits(int_info.bits) orelse {
+    if (int_info.bits > 128) {
         return func.fail("TODO: Implement {{add/sub}}_with_overflow for integer bitsize: {d}", .{int_info.bits});
-    };
-
-    if (wasm_bits == 128) {
-        const result = try func.addSubWithOverflowBigInt(lhs, rhs, lhs_ty, func.typeOfIndex(inst), op);
-        return func.finishAir(inst, result, &.{ extra.lhs, extra.rhs });
     }
 
-    const zero: WValue = switch (wasm_bits) {
-        32 => .{ .imm32 = 0 },
-        64 => .{ .imm64 = 0 },
+    const op_result = try func.wrapBinOp(lhs, rhs, ty, op);
+    var op_tmp = try op_result.toLocal(func, ty);
+    defer op_tmp.free(func);
+
+    const cmp_op: std.math.CompareOperator = switch (op) {
+        .add => .lt,
+        .sub => .gt,
         else => unreachable,
     };
-
-    const bin_op = try (try func.binOp(lhs, rhs, lhs_ty, op)).toLocal(func, lhs_ty);
-    var result = if (wasm_bits != int_info.bits) blk: {
-        break :blk try (try func.wrapOperand(bin_op, lhs_ty)).toLocal(func, lhs_ty);
-    } else bin_op;
-    defer result.free(func);
-
-    const cmp_op: std.math.CompareOperator = if (op == .sub) .gt else .lt;
-    const overflow_bit: WValue = if (is_signed) blk: {
-        if (wasm_bits == int_info.bits) {
-            const cmp_zero = try func.cmp(rhs, zero, lhs_ty, cmp_op);
-            const lt = try func.cmp(bin_op, lhs, lhs_ty, .lt);
-            break :blk try func.binOp(cmp_zero, lt, Type.u32, .xor);
-        }
-        break :blk try func.cmp(bin_op, bin_op, lhs_ty, .neq);
-    } else if (wasm_bits == int_info.bits)
-        try func.cmp(bin_op, lhs, lhs_ty, cmp_op)
-    else
-        try func.cmp(bin_op, result, lhs_ty, .neq);
-    var overflow_local = try overflow_bit.toLocal(func, Type.u32);
-    defer overflow_local.free(func);
-
-    const result_ptr = try func.allocStack(func.typeOfIndex(inst));
-    try func.store(result_ptr, result, lhs_ty, 0);
-    const offset = @as(u32, @intCast(lhs_ty.abiSize(pt)));
-    try func.store(result_ptr, overflow_local, Type.u1, offset);
-
-    return func.finishAir(inst, result_ptr, &.{ extra.lhs, extra.rhs });
-}
-
-fn addSubWithOverflowBigInt(func: *CodeGen, lhs: WValue, rhs: WValue, ty: Type, result_ty: Type, op: Op) InnerError!WValue {
-    const pt = func.pt;
-    const mod = pt.zcu;
-    assert(op == .add or op == .sub);
-    const int_info = ty.intInfo(mod);
-    const is_signed = int_info.signedness == .signed;
-    if (int_info.bits != 128) {
-        return func.fail("TODO: Implement @{{add/sub}}WithOverflow for integer bitsize '{d}'", .{int_info.bits});
-    }
-
-    var lhs_high_bit = try (try func.load(lhs, Type.u64, 0)).toLocal(func, Type.u64);
-    defer lhs_high_bit.free(func);
-    var lhs_low_bit = try (try func.load(lhs, Type.u64, 8)).toLocal(func, Type.u64);
-    defer lhs_low_bit.free(func);
-    var rhs_high_bit = try (try func.load(rhs, Type.u64, 0)).toLocal(func, Type.u64);
-    defer rhs_high_bit.free(func);
-    var rhs_low_bit = try (try func.load(rhs, Type.u64, 8)).toLocal(func, Type.u64);
-    defer rhs_low_bit.free(func);
-
-    var low_op_res = try (try func.binOp(lhs_low_bit, rhs_low_bit, Type.u64, op)).toLocal(func, Type.u64);
-    defer low_op_res.free(func);
-    var high_op_res = try (try func.binOp(lhs_high_bit, rhs_high_bit, Type.u64, op)).toLocal(func, Type.u64);
-    defer high_op_res.free(func);
-
-    var lt = if (op == .add) blk: {
-        break :blk try (try func.cmp(high_op_res, lhs_high_bit, Type.u64, .lt)).toLocal(func, Type.u32);
-    } else if (op == .sub) blk: {
-        break :blk try (try func.cmp(lhs_high_bit, rhs_high_bit, Type.u64, .lt)).toLocal(func, Type.u32);
-    } else unreachable;
-    defer lt.free(func);
-    var tmp = try (try func.intcast(lt, Type.u32, Type.u64)).toLocal(func, Type.u64);
-    defer tmp.free(func);
-    var tmp_op = try (try func.binOp(low_op_res, tmp, Type.u64, op)).toLocal(func, Type.u64);
-    defer tmp_op.free(func);
-
     const overflow_bit = if (is_signed) blk: {
-        const xor_low = try func.binOp(lhs_low_bit, rhs_low_bit, Type.u64, .xor);
-        const to_wrap = if (op == .add) wrap: {
-            break :wrap try func.binOp(xor_low, .{ .imm64 = ~@as(u64, 0) }, Type.u64, .xor);
-        } else xor_low;
-        const xor_op = try func.binOp(lhs_low_bit, tmp_op, Type.u64, .xor);
-        const wrap = try func.binOp(to_wrap, xor_op, Type.u64, .@"and");
-        break :blk try func.cmp(wrap, .{ .imm64 = 0 }, Type.i64, .lt); // i64 because signed
-    } else blk: {
-        const first_arg = if (op == .sub) arg: {
-            break :arg try func.cmp(high_op_res, lhs_high_bit, Type.u64, .gt);
-        } else lt;
-
-        try func.emitWValue(first_arg);
-        _ = try func.cmp(tmp_op, lhs_low_bit, Type.u64, if (op == .add) .lt else .gt);
-        _ = try func.cmp(tmp_op, lhs_low_bit, Type.u64, .eq);
-        try func.addTag(.select);
-
-        break :blk .stack;
-    };
-    var overflow_local = try overflow_bit.toLocal(func, Type.u1);
-    defer overflow_local.free(func);
-
-    const result_ptr = try func.allocStack(result_ty);
-    try func.store(result_ptr, high_op_res, Type.u64, 0);
-    try func.store(result_ptr, tmp_op, Type.u64, 8);
-    try func.store(result_ptr, overflow_local, Type.u1, 16);
-
-    return result_ptr;
+        const zero = try intZeroValue(func, ty);
+        const rhs_is_neg = try func.cmp(rhs, zero, ty, .lt);
+        const overflow_cmp = try func.cmp(op_tmp, lhs, ty, cmp_op);
+        break :blk try func.cmp(rhs_is_neg, overflow_cmp, Type.u1, .neq);
+    } else try func.cmp(op_tmp, lhs, ty, cmp_op);
+    var bit_tmp = try overflow_bit.toLocal(func, Type.u1);
+    defer bit_tmp.free(func);
+
+    const result = try func.allocStack(func.typeOfIndex(inst));
+    const offset: u32 = @intCast(ty.abiSize(pt));
+    try func.store(result, op_tmp, ty, 0);
+    try func.store(result, bit_tmp, Type.u1, offset);
+
+    return func.finishAir(inst, result, &.{ extra.lhs, extra.rhs });
 }
 
 fn airShlWithOverflow(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
test/behavior/math.zig
@@ -828,56 +828,72 @@ test "128-bit multiplication" {
     }
 }
 
+fn testAddWithOverflow(comptime T: type, a: T, b: T, add: T, bit: u1) !void {
+    const ov = @addWithOverflow(a, b);
+    try expect(ov[0] == add);
+    try expect(ov[1] == bit);
+}
+
 test "@addWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
-    {
-        var a: u8 = 250;
-        _ = &a;
-        const ov = @addWithOverflow(a, 100);
-        try expect(ov[0] == 94);
-        try expect(ov[1] == 1);
-    }
-    {
-        var a: u8 = 100;
-        _ = &a;
-        const ov = @addWithOverflow(a, 150);
-        try expect(ov[0] == 250);
-        try expect(ov[1] == 0);
-    }
-    {
-        var a: u8 = 200;
-        _ = &a;
-        var b: u8 = 99;
-        var ov = @addWithOverflow(a, b);
-        try expect(ov[0] == 43);
-        try expect(ov[1] == 1);
-        b = 55;
-        ov = @addWithOverflow(a, b);
-        try expect(ov[0] == 255);
-        try expect(ov[1] == 0);
-    }
+    try testAddWithOverflow(u8, 250, 100, 94, 1);
+    try testAddWithOverflow(u8, 100, 150, 250, 0);
 
-    {
-        var a: usize = 6;
-        var b: usize = 6;
-        _ = .{ &a, &b };
-        const ov = @addWithOverflow(a, b);
-        try expect(ov[0] == 12);
-        try expect(ov[1] == 0);
-    }
+    try testAddWithOverflow(u8, 200, 99, 43, 1);
+    try testAddWithOverflow(u8, 200, 55, 255, 0);
 
-    {
-        var a: isize = -6;
-        var b: isize = -6;
-        _ = .{ &a, &b };
-        const ov = @addWithOverflow(a, b);
-        try expect(ov[0] == -12);
-        try expect(ov[1] == 0);
-    }
+    try testAddWithOverflow(usize, 6, 6, 12, 0);
+    try testAddWithOverflow(usize, maxInt(usize), 6, 5, 1);
+
+    try testAddWithOverflow(isize, -6, -6, -12, 0);
+    try testAddWithOverflow(isize, minInt(isize), -6, maxInt(isize) - 5, 1);
+}
+
+test "@addWithOverflow > 64 bits" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    try testAddWithOverflow(u65, 4, 105, 109, 0);
+    try testAddWithOverflow(u65, 1000, 100, 1100, 0);
+    try testAddWithOverflow(u65, 100, maxInt(u65) - 99, 0, 1);
+    try testAddWithOverflow(u65, maxInt(u65), maxInt(u65), maxInt(u65) - 1, 1);
+    try testAddWithOverflow(u65, maxInt(u65) - 1, maxInt(u65), maxInt(u65) - 2, 1);
+    try testAddWithOverflow(u65, maxInt(u65), maxInt(u65) - 1, maxInt(u65) - 2, 1);
+
+    try testAddWithOverflow(u128, 4, 105, 109, 0);
+    try testAddWithOverflow(u128, 1000, 100, 1100, 0);
+    try testAddWithOverflow(u128, 100, maxInt(u128) - 99, 0, 1);
+    try testAddWithOverflow(u128, maxInt(u128), maxInt(u128), maxInt(u128) - 1, 1);
+    try testAddWithOverflow(u128, maxInt(u128) - 1, maxInt(u128), maxInt(u128) - 2, 1);
+    try testAddWithOverflow(u128, maxInt(u128), maxInt(u128) - 1, maxInt(u128) - 2, 1);
+
+    try testAddWithOverflow(i65, 4, -105, -101, 0);
+    try testAddWithOverflow(i65, 1000, 100, 1100, 0);
+    try testAddWithOverflow(i65, minInt(i65), 1, minInt(i65) + 1, 0);
+    try testAddWithOverflow(i65, maxInt(i65), minInt(i65), -1, 0);
+    try testAddWithOverflow(i65, minInt(i65), maxInt(i65), -1, 0);
+    try testAddWithOverflow(i65, maxInt(i65), -2, maxInt(i65) - 2, 0);
+    try testAddWithOverflow(i65, maxInt(i65), maxInt(i65), -2, 1);
+    try testAddWithOverflow(i65, minInt(i65), minInt(i65), 0, 1);
+    try testAddWithOverflow(i65, maxInt(i65) - 1, maxInt(i65), -3, 1);
+    try testAddWithOverflow(i65, maxInt(i65), maxInt(i65) - 1, -3, 1);
+
+    try testAddWithOverflow(i128, 4, -105, -101, 0);
+    try testAddWithOverflow(i128, 1000, 100, 1100, 0);
+    try testAddWithOverflow(i128, minInt(i128), 1, minInt(i128) + 1, 0);
+    try testAddWithOverflow(i128, maxInt(i128), minInt(i128), -1, 0);
+    try testAddWithOverflow(i128, minInt(i128), maxInt(i128), -1, 0);
+    try testAddWithOverflow(i128, maxInt(i128), -2, maxInt(i128) - 2, 0);
+    try testAddWithOverflow(i128, maxInt(i128), maxInt(i128), -2, 1);
+    try testAddWithOverflow(i128, minInt(i128), minInt(i128), 0, 1);
+    try testAddWithOverflow(i128, maxInt(i128) - 1, maxInt(i128), -3, 1);
+    try testAddWithOverflow(i128, maxInt(i128), maxInt(i128) - 1, -3, 1);
 }
 
 test "small int addition" {
@@ -1265,56 +1281,68 @@ test "@mulWithOverflow u256" {
     }
 }
 
+fn testSubWithOverflow(comptime T: type, a: T, b: T, sub: T, bit: u1) !void {
+    const ov = @subWithOverflow(a, b);
+    try expect(ov[0] == sub);
+    try expect(ov[1] == bit);
+}
+
 test "@subWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
-    {
-        var a: u8 = 1;
-        _ = &a;
-        const ov = @subWithOverflow(a, 2);
-        try expect(ov[0] == 255);
-        try expect(ov[1] == 1);
-    }
-    {
-        var a: u8 = 1;
-        _ = &a;
-        const ov = @subWithOverflow(a, 1);
-        try expect(ov[0] == 0);
-        try expect(ov[1] == 0);
-    }
+    try testSubWithOverflow(u8, 1, 2, 255, 1);
+    try testSubWithOverflow(u8, 1, 1, 0, 0);
 
-    {
-        var a: u8 = 1;
-        _ = &a;
-        var b: u8 = 2;
-        var ov = @subWithOverflow(a, b);
-        try expect(ov[0] == 255);
-        try expect(ov[1] == 1);
-        b = 1;
-        ov = @subWithOverflow(a, b);
-        try expect(ov[0] == 0);
-        try expect(ov[1] == 0);
-    }
+    try testSubWithOverflow(u16, 10000, 10002, 65534, 1);
+    try testSubWithOverflow(u16, 10000, 9999, 1, 0);
 
-    {
-        var a: usize = 6;
-        var b: usize = 6;
-        _ = .{ &a, &b };
-        const ov = @subWithOverflow(a, b);
-        try expect(ov[0] == 0);
-        try expect(ov[1] == 0);
-    }
+    try testSubWithOverflow(usize, 6, 6, 0, 0);
+    try testSubWithOverflow(usize, 6, 7, maxInt(usize), 1);
+    try testSubWithOverflow(isize, -6, -6, 0, 0);
+    try testSubWithOverflow(isize, minInt(isize), 6, maxInt(isize) - 5, 1);
+}
 
-    {
-        var a: isize = -6;
-        var b: isize = -6;
-        _ = .{ &a, &b };
-        const ov = @subWithOverflow(a, b);
-        try expect(ov[0] == 0);
-        try expect(ov[1] == 0);
-    }
+test "@subWithOverflow > 64 bits" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    try testSubWithOverflow(u65, 4, 105, maxInt(u65) - 100, 1);
+    try testSubWithOverflow(u65, 1000, 100, 900, 0);
+    try testSubWithOverflow(u65, maxInt(u65), maxInt(u65), 0, 0);
+    try testSubWithOverflow(u65, maxInt(u65) - 1, maxInt(u65), maxInt(u65), 1);
+    try testSubWithOverflow(u65, maxInt(u65), maxInt(u65) - 1, 1, 0);
+
+    try testSubWithOverflow(u128, 4, 105, maxInt(u128) - 100, 1);
+    try testSubWithOverflow(u128, 1000, 100, 900, 0);
+    try testSubWithOverflow(u128, maxInt(u128), maxInt(u128), 0, 0);
+    try testSubWithOverflow(u128, maxInt(u128) - 1, maxInt(u128), maxInt(u128), 1);
+    try testSubWithOverflow(u128, maxInt(u128), maxInt(u128) - 1, 1, 0);
+
+    try testSubWithOverflow(i65, 4, 105, -101, 0);
+    try testSubWithOverflow(i65, 1000, 100, 900, 0);
+    try testSubWithOverflow(i65, maxInt(i65), maxInt(i65), 0, 0);
+    try testSubWithOverflow(i65, minInt(i65), minInt(i65), 0, 0);
+    try testSubWithOverflow(i65, maxInt(i65) - 1, maxInt(i65), -1, 0);
+    try testSubWithOverflow(i65, maxInt(i65), maxInt(i65) - 1, 1, 0);
+    try testSubWithOverflow(i65, minInt(i65), 1, maxInt(i65), 1);
+    try testSubWithOverflow(i65, maxInt(i65), minInt(i65), -1, 1);
+    try testSubWithOverflow(i65, minInt(i65), maxInt(i65), 1, 1);
+    try testSubWithOverflow(i65, maxInt(i65), -2, minInt(i65) + 1, 1);
+
+    try testSubWithOverflow(i128, 4, 105, -101, 0);
+    try testSubWithOverflow(i128, 1000, 100, 900, 0);
+    try testSubWithOverflow(i128, maxInt(i128), maxInt(i128), 0, 0);
+    try testSubWithOverflow(i128, minInt(i128), minInt(i128), 0, 0);
+    try testSubWithOverflow(i128, maxInt(i128) - 1, maxInt(i128), -1, 0);
+    try testSubWithOverflow(i128, maxInt(i128), maxInt(i128) - 1, 1, 0);
+    try testSubWithOverflow(i128, minInt(i128), 1, maxInt(i128), 1);
+    try testSubWithOverflow(i128, maxInt(i128), minInt(i128), -1, 1);
+    try testSubWithOverflow(i128, minInt(i128), maxInt(i128), 1, 1);
+    try testSubWithOverflow(i128, maxInt(i128), -2, minInt(i128) + 1, 1);
 }
 
 test "@shlWithOverflow" {