Commit 19d5ffc710

Veikka Tuominen <git@vexu.eu>
2022-08-05 16:47:01
Sema: add safety check for non-power-of-two shift amounts
1 parent 9116e26
src/Sema.zig
@@ -10227,34 +10227,57 @@ fn zirShl(
     } else rhs;
 
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    if (block.wantSafety() and air_tag == .shl_exact) {
-        const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
-        const op_ov = try block.addInst(.{
-            .tag = .shl_with_overflow,
-            .data = .{ .ty_pl = .{
-                .ty = try sema.addType(op_ov_tuple_ty),
-                .payload = try sema.addExtra(Air.Bin{
-                    .lhs = lhs,
-                    .rhs = rhs,
-                }),
-            } },
-        });
-        const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
-        const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
-            try block.addInst(.{
-                .tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
-                .data = .{ .reduce = .{
-                    .operand = ov_bit,
-                    .operation = .Or,
+    if (block.wantSafety()) {
+        const bit_count = scalar_ty.intInfo(target).bits;
+        if (!std.math.isPowerOfTwo(bit_count)) {
+            const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
+
+            const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
+                const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
+                const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
+                break :ok try block.addInst(.{
+                    .tag = .reduce,
+                    .data = .{ .reduce = .{
+                        .operand = lt,
+                        .operation = .And,
+                    } },
+                });
+            } else ok: {
+                const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
+                break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
+            };
+            try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
+        }
+
+        if (air_tag == .shl_exact) {
+            const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
+            const op_ov = try block.addInst(.{
+                .tag = .shl_with_overflow,
+                .data = .{ .ty_pl = .{
+                    .ty = try sema.addType(op_ov_tuple_ty),
+                    .payload = try sema.addExtra(Air.Bin{
+                        .lhs = lhs,
+                        .rhs = rhs,
+                    }),
                 } },
-            })
-        else
-            ov_bit;
-        const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
-        const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
+            });
+            const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
+            const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
+                try block.addInst(.{
+                    .tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
+                    .data = .{ .reduce = .{
+                        .operand = ov_bit,
+                        .operation = .Or,
+                    } },
+                })
+            else
+                ov_bit;
+            const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
+            const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
 
-        try sema.addSafetyCheck(block, no_ov, .shl_overflow);
-        return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
+            try sema.addSafetyCheck(block, no_ov, .shl_overflow);
+            return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
+        }
     }
     return block.addBinOp(air_tag, lhs, new_rhs);
 }
@@ -10333,20 +10356,43 @@ fn zirShr(
 
     try sema.requireRuntimeBlock(block, src, runtime_src);
     const result = try block.addBinOp(air_tag, lhs, rhs);
-    if (block.wantSafety() and air_tag == .shr_exact) {
-        const back = try block.addBinOp(.shl, result, rhs);
-
-        const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
-            const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
-            break :ok try block.addInst(.{
-                .tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
-                .data = .{ .reduce = .{
-                    .operand = eql,
-                    .operation = .And,
-                } },
-            });
-        } else try block.addBinOp(.cmp_eq, lhs, back);
-        try sema.addSafetyCheck(block, ok, .shr_overflow);
+    if (block.wantSafety()) {
+        const bit_count = scalar_ty.intInfo(target).bits;
+        if (!std.math.isPowerOfTwo(bit_count)) {
+            const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
+
+            const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
+                const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
+                const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
+                break :ok try block.addInst(.{
+                    .tag = .reduce,
+                    .data = .{ .reduce = .{
+                        .operand = lt,
+                        .operation = .And,
+                    } },
+                });
+            } else ok: {
+                const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
+                break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
+            };
+            try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
+        }
+
+        if (air_tag == .shr_exact) {
+            const back = try block.addBinOp(.shl, result, rhs);
+
+            const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
+                const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
+                break :ok try block.addInst(.{
+                    .tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
+                    .data = .{ .reduce = .{
+                        .operand = eql,
+                        .operation = .And,
+                    } },
+                });
+            } else try block.addBinOp(.cmp_eq, lhs, back);
+            try sema.addSafetyCheck(block, ok, .shr_overflow);
+        }
     }
     return result;
 }
@@ -19972,6 +20018,7 @@ pub const PanicId = enum {
     inactive_union_field,
     integer_part_out_of_bounds,
     corrupt_switch,
+    shift_rhs_too_big,
 };
 
 fn addSafetyCheck(
@@ -20268,6 +20315,7 @@ fn safetyPanic(
         .inactive_union_field => "access of inactive union field",
         .integer_part_out_of_bounds => "integer part of floating point value out of bounds",
         .corrupt_switch => "switch on corrupt value",
+        .shift_rhs_too_big => "shift amount is greater than the type size",
     };
 
     const msg_inst = msg_inst: {
test/cases/safety/shift left by huge amount.zig
@@ -17,5 +17,5 @@ pub fn main() !void {
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native
test/cases/safety/shift right by huge amount.zig
@@ -17,5 +17,5 @@ pub fn main() !void {
 }
 
 // run
-// backend=stage1
+// backend=llvm
 // target=native
test/cases/safety/signed integer division overflow - vectors.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "integer overflow")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -17,5 +19,5 @@ fn div(a: @Vector(4, i16), b: @Vector(4, i16)) @Vector(4, i16) {
     return @divTrunc(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native
test/cases/safety/signed integer division overflow.zig
@@ -1,9 +1,11 @@
 const std = @import("std");
 
 pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
-    _ = message;
     _ = stack_trace;
-    std.process.exit(0);
+    if (std.mem.eql(u8, message, "integer overflow")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
 }
 
 pub fn main() !void {
@@ -15,5 +17,5 @@ fn div(a: i16, b: i16) i16 {
     return @divTrunc(a, b);
 }
 // run
-// backend=stage1
-// target=native
\ No newline at end of file
+// backend=llvm
+// target=native