Commit aa29f4a803

riverbl <94326797+riverbl@users.noreply.github.com>
2021-12-21 13:45:48
stage1: fix saturating arithmetic producing incorrect results on type comptime_int, allow saturating left shift on type comptime int
1 parent aca665c
Changed files (3)
src/stage1/ir.cpp
@@ -10230,13 +10230,7 @@ static Stage1AirInst *ir_analyze_bit_shift(IrAnalyze *ira, Stage1ZirInstBinOp *b
         // comptime_int has no finite bit width
         casted_op2 = op2;
 
-        if (op_id == IrBinOpShlSat) {
-            ir_add_error_node(ira, bin_op_instruction->base.source_node,
-                buf_sprintf("saturating shift on a comptime_int which has unlimited bits"));
-            return ira->codegen->invalid_inst_gen;
-        }
-
-        if (op_id == IrBinOpBitShiftLeftLossy) {
+        if (op_id == IrBinOpBitShiftLeftLossy || op_id == IrBinOpShlSat) {
             op_id = IrBinOpBitShiftLeftExact;
         }
 
@@ -10398,6 +10392,25 @@ static bool ok_float_op(IrBinOp op) {
     zig_unreachable();
 }
 
+static IrBinOp map_comptime_arithmetic_op(IrBinOp op) {
+    switch (op) {
+        case IrBinOpAddWrap:
+        case IrBinOpAddSat:
+            return IrBinOpAdd;
+
+        case IrBinOpSubWrap:
+        case IrBinOpSubSat:
+            return IrBinOpSub;
+
+        case IrBinOpMultWrap:
+        case IrBinOpMultSat:
+            return IrBinOpMult;
+
+        default:
+            return op;
+    }
+}
+
 static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
     switch (op) {
         case IrBinOpAdd:
@@ -10620,15 +10633,10 @@ static Stage1AirInst *ir_analyze_bin_op_math(IrAnalyze *ira, Stage1ZirInstBinOp
     if (type_is_invalid(casted_op2->value->type))
         return ira->codegen->invalid_inst_gen;
 
-    // Comptime integers have no fixed size
+    // Comptime integers have no fixed size, so wrapping or saturating operations should be mapped
+    // to their non wrapping or saturating equivalents
     if (scalar_type->id == ZigTypeIdComptimeInt) {
-        if (op_id == IrBinOpAddWrap) {
-            op_id = IrBinOpAdd;
-        } else if (op_id == IrBinOpSubWrap) {
-            op_id = IrBinOpSub;
-        } else if (op_id == IrBinOpMultWrap) {
-            op_id = IrBinOpMult;
-        }
+        op_id = map_comptime_arithmetic_op(op_id);
     }
 
     if (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2)) {
test/behavior/saturating_arithmetic.zig
@@ -29,8 +29,14 @@ test "saturating add" {
             try expect(x == expected);
         }
     };
+
     try S.doTheTest();
     comptime try S.doTheTest();
+
+    comptime try S.testSatAdd(comptime_int, 0, 0, 0);
+    comptime try S.testSatAdd(comptime_int, 3, 2, 5);
+    comptime try S.testSatAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512);
+    comptime try S.testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501);
 }
 
 test "saturating subtraction" {
@@ -56,8 +62,14 @@ test "saturating subtraction" {
             try expect(x == expected);
         }
     };
+
     try S.doTheTest();
     comptime try S.doTheTest();
+
+    comptime try S.testSatSub(comptime_int, 0, 0, 0);
+    comptime try S.testSatSub(comptime_int, 3, 2, 1);
+    comptime try S.testSatSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602);
+    comptime try S.testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515);
 }
 
 test "saturating multiplication" {
@@ -90,6 +102,11 @@ test "saturating multiplication" {
 
     try S.doTheTest();
     comptime try S.doTheTest();
+
+    comptime try S.testSatMul(comptime_int, 0, 0, 0);
+    comptime try S.testSatMul(comptime_int, 3, 2, 6);
+    comptime try S.testSatMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935);
+    comptime try S.testSatMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556);
 }
 
 test "saturating shift-left" {
@@ -107,6 +124,7 @@ test "saturating shift-left" {
             try testSatShl(u8, 1, 2, 4);
             try testSatShl(u8, 255, 1, 255);
         }
+
         fn testSatShl(comptime T: type, lhs: T, rhs: T, expected: T) !void {
             try expect((lhs <<| rhs) == expected);
 
@@ -115,8 +133,14 @@ test "saturating shift-left" {
             try expect(x == expected);
         }
     };
+
     try S.doTheTest();
     comptime try S.doTheTest();
+
+    comptime try S.testSatShl(comptime_int, 0, 0, 0);
+    comptime try S.testSatShl(comptime_int, 1, 2, 4);
+    comptime try S.testSatShl(comptime_int, 13, 150, 18554220005177478453757717602843436772975706112);
+    comptime try S.testSatShl(comptime_int, -582769, 180, -893090893854873184096635538665358532628308979495815656505344);
 }
 
 test "saturating shl uses the LHS type" {
@@ -139,4 +163,6 @@ test "saturating shl uses the LHS type" {
     try expect((@as(u8, 1) <<| 8) == 255);
     try expect((@as(u8, 1) <<| rhs_const) == 255);
     try expect((@as(u8, 1) <<| rhs_var) == 255);
+
+    try expect((1 <<| @as(u8, 200)) == 1606938044258990275541962092341162602522202993782792835301376);
 }
test/behavior/wrapping_arithmetic.zig
@@ -0,0 +1,110 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const minInt = std.math.minInt;
+const maxInt = std.math.maxInt;
+const expect = std.testing.expect;
+
+test "wrapping add" {
+    const S = struct {
+        fn doTheTest() !void {
+            try testWrapAdd(i8, -3, 10, 7);
+            try testWrapAdd(i8, -128, -128, 0);
+            try testWrapAdd(i2, 1, 1, -2);
+            try testWrapAdd(i64, maxInt(i64), 1, minInt(i64));
+            try testWrapAdd(i128, maxInt(i128), -maxInt(i128), 0);
+            try testWrapAdd(i128, minInt(i128), maxInt(i128), -1);
+            try testWrapAdd(i8, 127, 127, -2);
+            try testWrapAdd(u8, 3, 10, 13);
+            try testWrapAdd(u8, 255, 255, 254);
+            try testWrapAdd(u2, 3, 2, 1);
+            try testWrapAdd(u3, 7, 1, 0);
+            try testWrapAdd(u128, maxInt(u128), 1, minInt(u128));
+        }
+
+        fn testWrapAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void {
+            try expect((lhs +% rhs) == expected);
+
+            var x = lhs;
+            x +%= rhs;
+            try expect(x == expected);
+        }
+    };
+
+    try S.doTheTest();
+    comptime try S.doTheTest();
+
+    comptime try S.testWrapAdd(comptime_int, 0, 0, 0);
+    comptime try S.testWrapAdd(comptime_int, 3, 2, 5);
+    comptime try S.testWrapAdd(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 1119305249183743626545271163355074748512);
+    comptime try S.testWrapAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501);
+}
+
+test "wrapping subtraction" {
+    const S = struct {
+        fn doTheTest() !void {
+            try testWrapSub(i8, -3, 10, -13);
+            try testWrapSub(i8, -128, -128, 0);
+            try testWrapSub(i8, -1, 127, -128);
+            try testWrapSub(i64, minInt(i64), 1, maxInt(i64));
+            try testWrapSub(i128, maxInt(i128), -1, minInt(i128));
+            try testWrapSub(i128, minInt(i128), -maxInt(i128), -1);
+            try testWrapSub(u8, 10, 3, 7);
+            try testWrapSub(u8, 0, 255, 1);
+            try testWrapSub(u5, 0, 31, 1);
+            try testWrapSub(u128, 0, maxInt(u128), 1);
+        }
+
+        fn testWrapSub(comptime T: type, lhs: T, rhs: T, expected: T) !void {
+            try expect((lhs -% rhs) == expected);
+
+            var x = lhs;
+            x -%= rhs;
+            try expect(x == expected);
+        }
+    };
+
+    try S.doTheTest();
+    comptime try S.doTheTest();
+
+    comptime try S.testWrapSub(comptime_int, 0, 0, 0);
+    comptime try S.testWrapSub(comptime_int, 3, 2, 1);
+    comptime try S.testWrapSub(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 182846383813587550256162760261375991602);
+    comptime try S.testWrapSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515);
+}
+
+test "wrapping multiplication" {
+    // TODO: once #9660 has been solved, remove this line
+    if (builtin.cpu.arch == .wasm32) return error.SkipZigTest;
+
+    const S = struct {
+        fn doTheTest() !void {
+            try testWrapMul(i8, -3, 10, -30);
+            try testWrapMul(i4, 2, 4, -8);
+            try testWrapMul(i8, 2, 127, -2);
+            try testWrapMul(i8, -128, -128, 0);
+            try testWrapMul(i8, maxInt(i8), maxInt(i8), 1);
+            try testWrapMul(i16, maxInt(i16), -1, minInt(i16) + 1);
+            try testWrapMul(i128, maxInt(i128), -1, minInt(i128) + 1);
+            try testWrapMul(i128, minInt(i128), -1, minInt(i128));
+            try testWrapMul(u8, 10, 3, 30);
+            try testWrapMul(u8, 2, 255, 254);
+            try testWrapMul(u128, maxInt(u128), maxInt(u128), 1);
+        }
+
+        fn testWrapMul(comptime T: type, lhs: T, rhs: T, expected: T) !void {
+            try expect((lhs *% rhs) == expected);
+
+            var x = lhs;
+            x *%= rhs;
+            try expect(x == expected);
+        }
+    };
+
+    try S.doTheTest();
+    comptime try S.doTheTest();
+
+    comptime try S.testWrapMul(comptime_int, 0, 0, 0);
+    comptime try S.testWrapMul(comptime_int, 3, 2, 6);
+    comptime try S.testWrapMul(comptime_int, 651075816498665588400716961808225370057, 468229432685078038144554201546849378455, 304852860194144160265083087140337419215516305999637969803722975979232817921935);
+    comptime try S.testWrapMul(comptime_int, 7, -593423721213448152027139550640105366508, -4153966048494137064189976854480737565556);
+}