Commit 55fe34100f

Veikka Tuominen <git@vexu.eu>
2022-07-16 15:10:11
Sema: exact division safety
1 parent 76d0999
src/Sema.zig
@@ -11917,6 +11917,47 @@ fn analyzeArithmetic(
             },
             else => {},
         }
+        if (rs.air_tag == .div_exact) {
+            const result = try block.addBinOp(.div_exact, casted_lhs, casted_rhs);
+            const ok = if (scalar_tag == .Float) ok: {
+                const floored = try block.addUnOp(.floor, result);
+
+                if (resolved_type.zigTypeTag() == .Vector) {
+                    const eql = try block.addCmpVector(result, floored, .eq, try sema.addType(resolved_type));
+                    break :ok try block.addInst(.{
+                        .tag = .reduce,
+                        .data = .{ .reduce = .{
+                            .operand = eql,
+                            .operation = .And,
+                        } },
+                    });
+                } else {
+                    const is_in_range = try block.addBinOp(.cmp_eq, result, floored);
+                    break :ok is_in_range;
+                }
+            } else ok: {
+                const remainder = try block.addBinOp(.rem, casted_lhs, casted_rhs);
+
+                if (resolved_type.zigTypeTag() == .Vector) {
+                    const zero_val = try Value.Tag.repeated.create(sema.arena, Value.zero);
+                    const zero = try sema.addConstant(sema.typeOf(casted_rhs), zero_val);
+                    const eql = try block.addCmpVector(remainder, zero, .eq, try sema.addType(resolved_type));
+                    break :ok try block.addInst(.{
+                        .tag = .reduce,
+                        .data = .{ .reduce = .{
+                            .operand = eql,
+                            .operation = .And,
+                        } },
+                    });
+                } else {
+                    const zero = try sema.addConstant(sema.typeOf(casted_rhs), Value.zero);
+                    const is_in_range = try block.addBinOp(.cmp_eq, remainder, zero);
+                    break :ok is_in_range;
+                }
+            };
+            try sema.addSafetyCheck(block, ok, .exact_division_remainder);
+            return result;
+        }
     }
     return block.addBinOp(rs.air_tag, casted_lhs, casted_rhs);
 }
@@ -18856,6 +18897,7 @@ pub const PanicId = enum {
     shr_overflow,
     divide_by_zero,
     remainder_division_zero_negative,
+    exact_division_remainder,
 };
 
 fn addSafetyCheck(
@@ -19077,6 +19119,7 @@ fn safetyPanic(
         .shr_overflow => "right shift overflowed bits",
         .divide_by_zero => "division by zero",
         .remainder_division_zero_negative => "remainder division by zero or negative value",
+        .exact_division_remainder => "exact division produced remainder",
     };
 
     const msg_inst = msg_inst: {
test/behavior/math.zig
@@ -377,6 +377,7 @@ fn testBinaryNot(x: u16) !void {
 }
 
 test "division" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
test/cases/safety/exact division failure - vectors.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "exact division produced remainder")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -17,5 +19,5 @@ fn divExact(a: @Vector(4, i32), b: @Vector(4, i32)) @Vector(4, i32) {
     return @divExact(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native
test/cases/safety/exact division failure.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "exact division produced remainder")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -15,5 +17,5 @@ fn divExact(a: i32, b: i32) i32 {
     return @divExact(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native