Commit 9f10dfcb54

Veikka Tuominen <git@vexu.eu>
2022-07-16 01:15:24
Sema: implement shr_exact runtime safety
1 parent 4d20d68
src/Sema.zig
@@ -9996,40 +9996,34 @@ fn zirShl(
     } else rhs;
 
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    if (block.wantSafety()) {
-        const maybe_op_ov: ?Air.Inst.Tag = switch (air_tag) {
-            .shl_exact => .shl_with_overflow,
-            else => null,
-        };
-        if (maybe_op_ov) |op_ov_tag| {
-            const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
-            const op_ov = try block.addInst(.{
-                .tag = op_ov_tag,
-                .data = .{ .ty_pl = .{
-                    .ty = try sema.addType(op_ov_tuple_ty),
-                    .payload = try sema.addExtra(Air.Bin{
-                        .lhs = lhs,
-                        .rhs = rhs,
-                    }),
+    if (block.wantSafety() and air_tag == .shl_exact) {
+        const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
+        const op_ov = try block.addInst(.{
+            .tag = .shl_with_overflow,
+            .data = .{ .ty_pl = .{
+                .ty = try sema.addType(op_ov_tuple_ty),
+                .payload = try sema.addExtra(Air.Bin{
+                    .lhs = lhs,
+                    .rhs = rhs,
+                }),
+            } },
+        });
+        const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
+        const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
+            try block.addInst(.{
+                .tag = .reduce,
+                .data = .{ .reduce = .{
+                    .operand = ov_bit,
+                    .operation = .Or,
                 } },
-            });
-            const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
-            const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
-                try block.addInst(.{
-                    .tag = .reduce,
-                    .data = .{ .reduce = .{
-                        .operand = ov_bit,
-                        .operation = .Or,
-                    } },
-                })
-            else
-                ov_bit;
-            const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
-            const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
+            })
+        else
+            ov_bit;
+        const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
+        const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
 
-            try sema.addSafetyCheck(block, no_ov, .shl_overflow);
-            return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
-        }
+        try sema.addSafetyCheck(block, no_ov, .shl_overflow);
+        return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
     }
     return block.addBinOp(air_tag, lhs, new_rhs);
 }
@@ -10107,7 +10101,23 @@ fn zirShr(
     } else rhs_src;
 
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    return block.addBinOp(air_tag, lhs, rhs);
+    const result = try block.addBinOp(air_tag, lhs, rhs);
+    if (block.wantSafety() and air_tag == .shr_exact) {
+        const back = try block.addBinOp(.shl, result, rhs);
+
+        const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
+            const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
+            break :ok try block.addInst(.{
+                .tag = .reduce,
+                .data = .{ .reduce = .{
+                    .operand = eql,
+                    .operation = .And,
+                } },
+            });
+        } else try block.addBinOp(.cmp_eq, lhs, back);
+        try sema.addSafetyCheck(block, ok, .shr_overflow);
+    }
+    return result;
 }
 
 fn zirBitwise(
@@ -18802,6 +18812,7 @@ pub const PanicId = enum {
     cast_truncated_data,
     integer_overflow,
     shl_overflow,
+    shr_overflow,
 };
 
 fn addSafetyCheck(
@@ -19019,6 +19030,7 @@ fn safetyPanic(
         .cast_truncated_data => "integer cast truncated bits",
         .integer_overflow => "integer overflow",
         .shl_overflow => "left shift overflowed bits",
+        .shr_overflow => "right shift overflowed bits",
     };
 
     const msg_inst = msg_inst: {
test/cases/safety/signed shift right overflow.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "right shift overflowed bits")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -15,5 +17,5 @@ fn shr(a: i16, b: u4) i16 {
     return @shrExact(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native
test/cases/safety/unsigned shift right overflow.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "right shift overflowed bits")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -15,5 +17,5 @@ fn shr(a: u16, b: u4) u16 {
     return @shrExact(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native