Commit 300fceac6e

LemonBoy <thatlemon@gmail.com>
2020-03-10 20:54:05
ir: Implement more safety checks for shl/shr
The checks are now valid on types whose size is not a power of two. Closes #2096
1 parent 9c4dc7b
src/all_types.hpp
@@ -1834,6 +1834,7 @@ enum PanicMsgId {
     PanicMsgIdBadNoAsyncCall,
     PanicMsgIdResumeNotSuspendedFn,
     PanicMsgIdBadSentinel,
+    PanicMsgIdShxTooBigRhs,
 
     PanicMsgIdCount,
 };
src/codegen.cpp
@@ -974,6 +974,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("resumed a non-suspended function");
         case PanicMsgIdBadSentinel:
             return buf_create_from_str("sentinel mismatch");
+        case PanicMsgIdShxTooBigRhs:
+            return buf_create_from_str("shift amount is greater than the type size");
     }
     zig_unreachable();
 }
@@ -2841,6 +2843,26 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
 
 }
 
+static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type, LLVMValueRef value) {
+    // We only check if the rhs value of the shift expression is greater or
+    // equal to the number of bits of the lhs if it's not a power of two,
+    // otherwise the check is useful as the allowed values are limited by the
+    // operand type itself
+    if (!is_power_of_2(lhs_type->data.integral.bit_count)) {
+        LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type),
+            lhs_type->data.integral.bit_count, false);
+        LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
+        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail");
+        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
+        LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, fail_block);
+        gen_safety_crash(g, PanicMsgIdShxTooBigRhs);
+
+        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    }
+}
+
 static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
         IrInstGenBinOp *bin_op_instruction)
 {
@@ -2949,6 +2971,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
             {
                 assert(scalar_type->id == ZigTypeIdInt);
                 LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+
+                if (want_runtime_safety) {
+                    gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
+                }
+
                 bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
                 if (is_sloppy) {
                     return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
@@ -2965,6 +2992,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
             {
                 assert(scalar_type->id == ZigTypeIdInt);
                 LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+
+                if (want_runtime_safety) {
+                    gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
+                }
+
                 bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
                 if (is_sloppy) {
                     if (scalar_type->data.integral.is_signed) {
src/ir.cpp
@@ -16648,36 +16648,34 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
             return ira->codegen->invalid_inst_gen;
         }
     } else {
+        assert(op1->value->type->data.integral.bit_count > 0);
         ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
-                op1->value->type->data.integral.bit_count - 1);
-        if (bin_op_instruction->op_id == IrBinOpBitShiftLeftLossy &&
-            op2->value->type->id == ZigTypeIdComptimeInt) {
+            op1->value->type->data.integral.bit_count - 1);
 
-            ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
+        casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
+        if (type_is_invalid(casted_op2->value->type))
+            return ira->codegen->invalid_inst_gen;
+
+        if (instr_is_comptime(casted_op2)) {
+            ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
             if (op2_val == nullptr)
                 return ira->codegen->invalid_inst_gen;
-            if (!bigint_fits_in_bits(&op2_val->data.x_bigint,
-                                     shift_amt_type->data.integral.bit_count,
-                                     op2_val->data.x_bigint.is_negative)) {
-                Buf *val_buf = buf_alloc();
-                bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
+
+            BigInt bit_count_value = {0};
+            bigint_init_unsigned(&bit_count_value, op1->value->type->data.integral.bit_count);
+
+            if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
                 ErrorMsg* msg = ir_add_error(ira,
                     &bin_op_instruction->base.base,
                     buf_sprintf("RHS of shift is too large for LHS type"));
-                add_error_note(
-                    ira->codegen,
-                    msg,
-                    op2->base.source_node,
-                    buf_sprintf("value %s cannot fit into type %s",
-                        buf_ptr(val_buf),
-                        buf_ptr(&shift_amt_type->name)));
+                add_error_note(ira->codegen, msg, op1->base.source_node,
+                    buf_sprintf("type %s has only %u bits",
+                        buf_ptr(&op1->value->type->name),
+                        op1->value->type->data.integral.bit_count));
+
                 return ira->codegen->invalid_inst_gen;
             }
         }
-
-        casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
-        if (type_is_invalid(casted_op2->value->type))
-            return ira->codegen->invalid_inst_gen;
     }
 
     if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
test/compile_errors.zig
@@ -2,6 +2,38 @@ const tests = @import("tests.zig");
 const std = @import("std");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.addTest("shift on type with non-power-of-two size",
+        \\export fn entry() void {
+        \\    const S = struct {
+        \\        fn a() void {
+        \\            var x: u24 = 42;
+        \\            _ = x >> 24;
+        \\        }
+        \\        fn b() void {
+        \\            var x: u24 = 42;
+        \\            _ = x << 24;
+        \\        }
+        \\        fn c() void {
+        \\            var x: u24 = 42;
+        \\            _ = @shlExact(x, 24);
+        \\        }
+        \\        fn d() void {
+        \\            var x: u24 = 42;
+        \\            _ = @shrExact(x, 24);
+        \\        }
+        \\    };
+        \\    S.a();
+        \\    S.b();
+        \\    S.c();
+        \\    S.d();
+        \\}
+    , &[_][]const u8{
+        "tmp.zig:5:19: error: RHS of shift is too large for LHS type",
+        "tmp.zig:9:19: error: RHS of shift is too large for LHS type",
+        "tmp.zig:13:17: error: RHS of shift is too large for LHS type",
+        "tmp.zig:17:17: error: RHS of shift is too large for LHS type",
+    });
+
     cases.addTest("combination of noasync and async",
         \\export fn entry() void {
         \\    noasync {
@@ -4029,8 +4061,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         \\}
         \\export fn entry() u16 { return f(); }
     , &[_][]const u8{
-        "tmp.zig:3:14: error: RHS of shift is too large for LHS type",
-        "tmp.zig:3:17: note: value 8 cannot fit into type u3",
+        "tmp.zig:3:17: error: integer value 8 cannot be coerced to type 'u3'",
     });
 
     cases.add("missing function call param",
test/runtime_safety.zig
@@ -1,6 +1,37 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompareOutputContext) void {
+    cases.addRuntimeSafety("shift left by huge amount",
+        \\const std = @import("std");
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    std.debug.warn("{}\n", .{message});
+        \\    if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
+        \\        std.process.exit(126); // good
+        \\    }
+        \\    std.process.exit(0); // test failed
+        \\}
+        \\pub fn main() void {
+        \\    var x: u24 = 42;
+        \\    var y: u5 = 24;
+        \\    var z = x >> y;
+        \\}
+    );
+
+    cases.addRuntimeSafety("shift right by huge amount",
+        \\const std = @import("std");
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
+        \\        std.process.exit(126); // good
+        \\    }
+        \\    std.process.exit(0); // test failed
+        \\}
+        \\pub fn main() void {
+        \\    var x: u24 = 42;
+        \\    var y: u5 = 24;
+        \\    var z = x << y;
+        \\}
+    );
+
     cases.addRuntimeSafety("slice sentinel mismatch - optional pointers",
         \\const std = @import("std");
         \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {