Commit f330eebe4b

Andrew Kelley <andrew@ziglang.org>
2019-02-07 22:02:45
fix using the result of @intCast to u0
closes #1817
1 parent 7843c96
src/all_types.hpp
@@ -2239,6 +2239,7 @@ enum IrInstructionId {
     IrInstructionIdCheckRuntimeScope,
     IrInstructionIdVectorToArray,
     IrInstructionIdArrayToVector,
+    IrInstructionIdAssertZero,
 };
 
 struct IrInstruction {
@@ -3381,6 +3382,12 @@ struct IrInstructionVectorToArray {
     LLVMValueRef tmp_ptr;
 };
 
+struct IrInstructionAssertZero {
+    IrInstruction base;
+
+    IrInstruction *target;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/codegen.cpp
@@ -1651,10 +1651,25 @@ static void add_bounds_check(CodeGen *g, LLVMValueRef target_val,
     LLVMPositionBuilderAtEnd(g->builder, ok_block);
 }
 
+static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) {
+    LLVMValueRef zero = LLVMConstNull(int_type->type_ref);
+    LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
+    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
+    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
+    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, fail_block);
+    gen_safety_crash(g, PanicMsgIdCastTruncatedData);
+
+    LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    return nullptr;
+}
+
 static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, ZigType *actual_type,
         ZigType *wanted_type, LLVMValueRef expr_val)
 {
     assert(actual_type->id == wanted_type->id);
+    assert(expr_val != nullptr);
 
     uint64_t actual_bits;
     uint64_t wanted_bits;
@@ -1707,17 +1722,7 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
                 if (!want_runtime_safety)
                     return nullptr;
 
-                LLVMValueRef zero = LLVMConstNull(actual_type->type_ref);
-                LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
-                LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
-                LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
-                LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
-
-                LLVMPositionBuilderAtEnd(g->builder, fail_block);
-                gen_safety_crash(g, PanicMsgIdCastTruncatedData);
-
-                LLVMPositionBuilderAtEnd(g->builder, ok_block);
-                return nullptr;
+                return gen_assert_zero(g, expr_val, actual_type);
             }
             LLVMValueRef trunc_val = LLVMBuildTrunc(g->builder, expr_val, wanted_type->type_ref, "");
             if (!want_runtime_safety) {
@@ -5209,6 +5214,17 @@ static LLVMValueRef ir_render_array_to_vector(CodeGen *g, IrExecutable *executab
     return gen_load_untyped(g, casted_ptr, 0, false, "");
 }
 
+static LLVMValueRef ir_render_assert_zero(CodeGen *g, IrExecutable *executable,
+        IrInstructionAssertZero *instruction)
+{
+    LLVMValueRef target = ir_llvm_value(g, instruction->target);
+    ZigType *int_type = instruction->target->value.type;
+    if (ir_want_runtime_safety(g, &instruction->base)) {
+        return gen_assert_zero(g, target, int_type);
+    }
+    return nullptr;
+}
+
 static void set_debug_location(CodeGen *g, IrInstruction *instruction) {
     AstNode *source_node = instruction->source_node;
     Scope *scope = instruction->scope;
@@ -5458,6 +5474,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction);
         case IrInstructionIdVectorToArray:
             return ir_render_vector_to_array(g, executable, (IrInstructionVectorToArray *)instruction);
+        case IrInstructionIdAssertZero:
+            return ir_render_assert_zero(g, executable, (IrInstructionAssertZero *)instruction);
     }
     zig_unreachable();
 }
src/ir.cpp
@@ -908,6 +908,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionArrayToVector *)
     return IrInstructionIdArrayToVector;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionAssertZero *) {
+    return IrInstructionIdAssertZero;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -2858,6 +2862,19 @@ static IrInstruction *ir_build_array_to_vector(IrAnalyze *ira, IrInstruction *so
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_assert_zero(IrAnalyze *ira, IrInstruction *source_instruction,
+        IrInstruction *target)
+{
+    IrInstructionAssertZero *instruction = ir_build_instruction<IrInstructionAssertZero>(&ira->new_irb,
+        source_instruction->scope, source_instruction->source_node);
+    instruction->base.value.type = ira->codegen->builtin_types.entry_void;
+    instruction->target = target;
+
+    ir_ref_instruction(target, ira->new_irb.current_basic_block);
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -10395,6 +10412,18 @@ static IrInstruction *ir_analyze_widen_or_shorten(IrAnalyze *ira, IrInstruction
         return result;
     }
 
+    // If the destination integer type has no bits, then we can emit a comptime
+    // zero. However, we still want to emit a runtime safety check to make sure
+    // the target is zero.
+    if (!type_has_bits(wanted_type)) {
+        assert(wanted_type->id == ZigTypeIdInt);
+        assert(type_has_bits(target->value.type));
+        ir_build_assert_zero(ira, source_instr, target);
+        IrInstruction *result = ir_const_unsigned(ira, source_instr, 0);
+        result->value.type = wanted_type;
+        return result;
+    }
+
     IrInstruction *result = ir_build_widen_or_shorten(&ira->new_irb, source_instr->scope,
             source_instr->source_node, target);
     result->value.type = wanted_type;
@@ -21705,6 +21734,7 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
         case IrInstructionIdCmpxchgGen:
         case IrInstructionIdArrayToVector:
         case IrInstructionIdVectorToArray:
+        case IrInstructionIdAssertZero:
             zig_unreachable();
 
         case IrInstructionIdReturn:
@@ -22103,6 +22133,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdAtomicRmw:
         case IrInstructionIdCmpxchgGen:
         case IrInstructionIdCmpxchgSrc:
+        case IrInstructionIdAssertZero:
             return true;
 
         case IrInstructionIdPhi:
src/ir_print.cpp
@@ -984,6 +984,12 @@ static void ir_print_vector_to_array(IrPrint *irp, IrInstructionVectorToArray *i
     fprintf(irp->f, ")");
 }
 
+static void ir_print_assert_zero(IrPrint *irp, IrInstructionAssertZero *instruction) {
+    fprintf(irp->f, "AssertZero(");
+    ir_print_other_instruction(irp, instruction->target);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_int_to_err(IrPrint *irp, IrInstructionIntToErr *instruction) {
     fprintf(irp->f, "inttoerr ");
     ir_print_other_instruction(irp, instruction->target);
@@ -1843,6 +1849,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdVectorToArray:
             ir_print_vector_to_array(irp, (IrInstructionVectorToArray *)instruction);
             break;
+        case IrInstructionIdAssertZero:
+            ir_print_assert_zero(irp, (IrInstructionAssertZero *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/stage1/behavior/cast.zig
@@ -471,3 +471,14 @@ test "@intToEnum passed a comptime_int to an enum with one item" {
     const x = @intToEnum(E, 0);
     assertOrPanic(x == E.A);
 }
+
+test "@intCast to u0 and use the result" {
+    const S = struct {
+        fn doTheTest(zero: u1, one: u1, bigzero: i32) void {
+            assertOrPanic((one << @intCast(u0, bigzero)) == 1);
+            assertOrPanic((zero << @intCast(u0, bigzero)) == 0);
+        }
+    };
+    S.doTheTest(0, 1, 0);
+    comptime S.doTheTest(0, 1, 0);
+}
test/stage1/behavior/eval.zig
@@ -697,12 +697,6 @@ test "bit shift a u1" {
     assertOrPanic(y == 1);
 }
 
-test "@intCast to a u0" {
-    var x: u8 = 0;
-    var y: u0 = @intCast(u0, x);
-    assertOrPanic(y == 0);
-}
-
 test "@bytesToslice on a packed struct" {
     const F = packed struct {
         a: u8,
test/runtime_safety.zig
@@ -362,6 +362,23 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
         \\}
     );
 
+    // @intCast a runtime integer to u0 actually results in a comptime-known value,
+    // but we still emit a safety check to ensure the integer was 0 and thus
+    // did not truncate information.
+    cases.addRuntimeSafety("@intCast to u0",
+        \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\
+        \\pub fn main() void {
+        \\    bar(1, 1);
+        \\}
+        \\
+        \\fn bar(one: u1, not_zero: i32) void {
+        \\    var x = one << @intCast(u0, not_zero);
+        \\}
+    );
+
     // This case makes sure that the code compiles and runs. There is not actually a special
     // runtime safety check having to do specifically with error return traces across suspend points.
     cases.addRuntimeSafety("error return trace across suspend points",