Commit 2419f0c914

Andrew Kelley <superjoe30@gmail.com>
2016-12-19 23:25:09
IR: support maybe defers
1 parent 09d50e3
src/all_types.hpp
@@ -1410,7 +1410,7 @@ enum IrInstructionId {
     IrInstructionIdAsm,
     IrInstructionIdCompileVar,
     IrInstructionIdSizeOf,
-    IrInstructionIdTestNull,
+    IrInstructionIdTestNonNull,
     IrInstructionIdUnwrapMaybe,
     IrInstructionIdMaybeWrap,
     IrInstructionIdEnumTag,
@@ -1787,7 +1787,7 @@ struct IrInstructionSizeOf {
 
 // returns true if nonnull, returns false if null
 // this is so that `zeroes` sets maybe values to null
-struct IrInstructionTestNull {
+struct IrInstructionTestNonNull {
     IrInstruction base;
 
     IrInstruction *value;
src/codegen.cpp
@@ -1554,26 +1554,20 @@ static LLVMValueRef ir_render_asm(CodeGen *g, IrExecutable *executable, IrInstru
     return LLVMBuildCall(g->builder, asm_fn, param_values, input_and_output_count, "");
 }
 
-// 0 - null, 1 - non null
-static LLVMValueRef gen_null_bit(CodeGen *g, TypeTableEntry *ptr_type, LLVMValueRef maybe_ptr) {
-    assert(ptr_type->id == TypeTableEntryIdPointer);
-    TypeTableEntry *maybe_type = ptr_type->data.pointer.child_type;
-    assert(maybe_type->id == TypeTableEntryIdMaybe);
-    TypeTableEntry *child_type = maybe_type->data.maybe.child_type;
-    LLVMValueRef maybe_struct_ref = get_handle_value(g, maybe_ptr, maybe_type);
-    bool maybe_is_ptr = (child_type->id == TypeTableEntryIdPointer || child_type->id == TypeTableEntryIdFn);
+static LLVMValueRef gen_non_null_bit(CodeGen *g, TypeTableEntry *maybe_type, LLVMValueRef maybe_handle) {
+    bool maybe_is_ptr = (maybe_type->id == TypeTableEntryIdPointer || maybe_type->id == TypeTableEntryIdFn);
     if (maybe_is_ptr) {
-        return LLVMBuildICmp(g->builder, LLVMIntNE, maybe_struct_ref, LLVMConstNull(child_type->type_ref), "");
+        return LLVMBuildICmp(g->builder, LLVMIntNE, maybe_handle, LLVMConstNull(maybe_type->type_ref), "");
     } else {
-        LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, maybe_struct_ref, maybe_null_index, "");
+        LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, maybe_handle, maybe_null_index, "");
         return LLVMBuildLoad(g->builder, maybe_field_ptr, "");
     }
 }
 
-static LLVMValueRef ir_render_test_null(CodeGen *g, IrExecutable *executable, IrInstructionTestNull *instruction) {
-    TypeTableEntry *ptr_type = instruction->value->type_entry;
-    assert(ptr_type->id == TypeTableEntryIdPointer);
-    return gen_null_bit(g, ptr_type, ir_llvm_value(g, instruction->value));
+static LLVMValueRef ir_render_test_non_null(CodeGen *g, IrExecutable *executable,
+    IrInstructionTestNonNull *instruction)
+{
+    return gen_non_null_bit(g, instruction->value->type_entry, ir_llvm_value(g, instruction->value));
 }
 
 static LLVMValueRef ir_render_unwrap_maybe(CodeGen *g, IrExecutable *executable,
@@ -1586,11 +1580,12 @@ static LLVMValueRef ir_render_unwrap_maybe(CodeGen *g, IrExecutable *executable,
     TypeTableEntry *child_type = maybe_type->data.maybe.child_type;
     bool maybe_is_ptr = (child_type->id == TypeTableEntryIdPointer || child_type->id == TypeTableEntryIdFn);
     LLVMValueRef maybe_ptr = ir_llvm_value(g, instruction->value);
+    LLVMValueRef maybe_handle = get_handle_value(g, maybe_ptr, maybe_type);
     if (ir_want_debug_safety(g, &instruction->base) && instruction->safety_check_on) {
-        LLVMValueRef nonnull_bit = gen_null_bit(g, ptr_type, maybe_ptr);
+        LLVMValueRef non_null_bit = gen_non_null_bit(g, maybe_type, maybe_handle);
         LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapMaybeOk");
         LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapMaybeFail");
-        LLVMBuildCondBr(g->builder, nonnull_bit, ok_block, fail_block);
+        LLVMBuildCondBr(g->builder, non_null_bit, ok_block, fail_block);
 
         LLVMPositionBuilderAtEnd(g->builder, fail_block);
         gen_debug_safety_crash(g);
@@ -2227,8 +2222,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_enum_field_ptr(g, executable, (IrInstructionEnumFieldPtr *)instruction);
         case IrInstructionIdAsm:
             return ir_render_asm(g, executable, (IrInstructionAsm *)instruction);
-        case IrInstructionIdTestNull:
-            return ir_render_test_null(g, executable, (IrInstructionTestNull *)instruction);
+        case IrInstructionIdTestNonNull:
+            return ir_render_test_non_null(g, executable, (IrInstructionTestNonNull *)instruction);
         case IrInstructionIdUnwrapMaybe:
             return ir_render_unwrap_maybe(g, executable, (IrInstructionUnwrapMaybe *)instruction);
         case IrInstructionIdClz:
src/ir.cpp
@@ -275,8 +275,8 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionSizeOf *) {
     return IrInstructionIdSizeOf;
 }
 
-static constexpr IrInstructionId ir_instruction_id(IrInstructionTestNull *) {
-    return IrInstructionIdTestNull;
+static constexpr IrInstructionId ir_instruction_id(IrInstructionTestNonNull *) {
+    return IrInstructionIdTestNonNull;
 }
 
 static constexpr IrInstructionId ir_instruction_id(IrInstructionUnwrapMaybe *) {
@@ -1200,7 +1200,7 @@ static IrInstruction *ir_build_size_of(IrBuilder *irb, Scope *scope, AstNode *so
 }
 
 static IrInstruction *ir_build_test_null(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) {
-    IrInstructionTestNull *instruction = ir_build_instruction<IrInstructionTestNull>(irb, scope, source_node);
+    IrInstructionTestNonNull *instruction = ir_build_instruction<IrInstructionTestNonNull>(irb, scope, source_node);
     instruction->value = value;
 
     ir_ref_instruction(value);
@@ -1956,10 +1956,27 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                     ir_gen_defers_for_block(irb, scope, outer_scope, false, false);
                     return ir_build_return(irb, scope, node, return_value);
                 } else if (defer_counts[ReturnKindMaybe] > 0) {
-                    // TODO in this situation we need to make a conditional
-                    // branch on the maybe value. we potentially must make multiple conditional branches,
-                    // if unconditional defers are interleaved with error defers.
-                    zig_panic("TODO handle maybe defers");
+                    IrBasicBlock *null_block = ir_build_basic_block(irb, scope, "MaybeRetNull");
+                    IrBasicBlock *ok_block = ir_build_basic_block(irb, scope, "MaybeRetOk");
+
+                    IrInstruction *is_non_null = ir_build_test_null(irb, scope, node, return_value);
+
+                    IrInstruction *is_comptime;
+                    if (ir_should_inline(irb)) {
+                        is_comptime = ir_build_const_bool(irb, scope, node, true);
+                    } else {
+                        is_comptime = ir_build_test_comptime(irb, scope, node, is_non_null);
+                    }
+
+                    ir_build_cond_br(irb, scope, node, is_non_null, ok_block, null_block, is_comptime);
+
+                    ir_set_cursor_at_end(irb, null_block);
+                    ir_gen_defers_for_block(irb, scope, outer_scope, false, true);
+                    ir_build_return(irb, scope, node, return_value);
+
+                    ir_set_cursor_at_end(irb, ok_block);
+                    ir_gen_defers_for_block(irb, scope, outer_scope, false, false);
+                    return ir_build_return(irb, scope, node, return_value);
                 } else {
                     // generate unconditional defers
                     ir_gen_defers_for_block(irb, scope, outer_scope, false, false);
@@ -1998,12 +2015,13 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                 IrInstruction *maybe_val_ptr = ir_gen_node_extra(irb, expr_node, scope, LValPurposeAddressOf);
                 if (maybe_val_ptr == irb->codegen->invalid_instruction)
                     return irb->codegen->invalid_instruction;
-                IrInstruction *is_nonnull_val = ir_build_test_null(irb, scope, node, maybe_val_ptr);
+                IrInstruction *maybe_val = ir_build_load_ptr(irb, scope, node, maybe_val_ptr);
+                IrInstruction *is_non_null = ir_build_test_null(irb, scope, node, maybe_val);
 
                 IrBasicBlock *return_block = ir_build_basic_block(irb, scope, "MaybeRetReturn");
                 IrBasicBlock *continue_block = ir_build_basic_block(irb, scope, "MaybeRetContinue");
                 IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node, ir_should_inline(irb));
-                ir_build_cond_br(irb, scope, node, is_nonnull_val, continue_block, return_block, is_comptime);
+                ir_build_cond_br(irb, scope, node, is_non_null, continue_block, return_block, is_comptime);
 
                 ir_set_cursor_at_end(irb, return_block);
                 ir_gen_defers_for_block(irb, scope, outer_scope, false, true);
@@ -2247,7 +2265,8 @@ static IrInstruction *ir_gen_maybe_ok_or(IrBuilder *irb, Scope *parent_scope, As
     if (maybe_ptr == irb->codegen->invalid_instruction)
         return irb->codegen->invalid_instruction;
 
-    IrInstruction *is_non_null = ir_build_test_null(irb, parent_scope, node, maybe_ptr);
+    IrInstruction *maybe_val = ir_build_load_ptr(irb, parent_scope, node, maybe_ptr);
+    IrInstruction *is_non_null = ir_build_test_null(irb, parent_scope, node, maybe_val);
 
     IrInstruction *is_comptime;
     if (ir_should_inline(irb)) {
@@ -3514,11 +3533,12 @@ static IrInstruction *ir_gen_if_var_expr(IrBuilder *irb, Scope *scope, AstNode *
     AstNode *else_node = node->data.if_var_expr.else_node;
     bool var_is_ptr = node->data.if_var_expr.var_is_ptr;
 
-    IrInstruction *expr_value = ir_gen_node_extra(irb, expr_node, scope, LValPurposeAddressOf);
-    if (expr_value == irb->codegen->invalid_instruction)
-        return expr_value;
+    IrInstruction *maybe_val_ptr = ir_gen_node_extra(irb, expr_node, scope, LValPurposeAddressOf);
+    if (maybe_val_ptr == irb->codegen->invalid_instruction)
+        return maybe_val_ptr;
 
-    IrInstruction *is_nonnull_value = ir_build_test_null(irb, scope, node, expr_value);
+    IrInstruction *maybe_val = ir_build_load_ptr(irb, scope, node, maybe_val_ptr);
+    IrInstruction *is_non_null = ir_build_test_null(irb, scope, node, maybe_val);
 
     IrBasicBlock *then_block = ir_build_basic_block(irb, scope, "MaybeThen");
     IrBasicBlock *else_block = ir_build_basic_block(irb, scope, "MaybeElse");
@@ -3528,9 +3548,9 @@ static IrInstruction *ir_gen_if_var_expr(IrBuilder *irb, Scope *scope, AstNode *
     if (ir_should_inline(irb) || node->data.if_var_expr.is_inline) {
         is_comptime = ir_build_const_bool(irb, scope, node, true);
     } else {
-        is_comptime = ir_build_test_comptime(irb, scope, node, is_nonnull_value);
+        is_comptime = ir_build_test_comptime(irb, scope, node, is_non_null);
     }
-    ir_build_cond_br(irb, scope, node, is_nonnull_value, then_block, else_block, is_comptime);
+    ir_build_cond_br(irb, scope, node, is_non_null, then_block, else_block, is_comptime);
 
     ir_set_cursor_at_end(irb, then_block);
     IrInstruction *var_type = nullptr;
@@ -3544,7 +3564,7 @@ static IrInstruction *ir_gen_if_var_expr(IrBuilder *irb, Scope *scope, AstNode *
     VariableTableEntry *var = ir_create_var(irb, node, scope,
             var_decl->symbol, is_const, is_const, is_shadowable, is_comptime);
 
-    IrInstruction *var_ptr_value = ir_build_unwrap_maybe(irb, scope, node, expr_value, false);
+    IrInstruction *var_ptr_value = ir_build_unwrap_maybe(irb, scope, node, maybe_val_ptr, false);
     IrInstruction *var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, node, var_ptr_value);
     ir_build_var_decl(irb, scope, node, var, var_type, var_value);
     IrInstruction *then_expr_result = ir_gen_node(irb, then_node, var->child_scope);
@@ -4440,7 +4460,7 @@ static ImplicitCastMatchResult ir_types_match_with_implicit_cast(IrAnalyze *ira,
     }
 
     // implicitly take a const pointer to something
-    {
+    if (!type_requires_comptime(actual_type)) {
         TypeTableEntry *const_ptr_actual = get_pointer_to_type(ira->codegen, actual_type, true);
         if (types_match_const_cast_only(expected_type, const_ptr_actual)) {
             return ImplicitCastMatchResultYes;
@@ -5230,7 +5250,7 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
     }
 
     // explicit cast from something to const pointer of it
-    {
+    if (!type_requires_comptime(actual_type)) {
         TypeTableEntry *const_ptr_actual = get_pointer_to_type(ira->codegen, actual_type, true);
         if (types_match_const_cast_only(wanted_type, const_ptr_actual)) {
             return ir_analyze_cast_ref(ira, source_instr, value, wanted_type);
@@ -7763,39 +7783,36 @@ static TypeTableEntry *ir_analyze_instruction_size_of(IrAnalyze *ira,
     zig_unreachable();
 }
 
-static TypeTableEntry *ir_analyze_instruction_test_null(IrAnalyze *ira,
-        IrInstructionTestNull *test_null_instruction)
-{
-    IrInstruction *value = test_null_instruction->value->other;
+static TypeTableEntry *ir_analyze_instruction_test_non_null(IrAnalyze *ira, IrInstructionTestNonNull *instruction) {
+    IrInstruction *value = instruction->value->other;
     if (value->type_entry->id == TypeTableEntryIdInvalid)
         return ira->codegen->builtin_types.entry_invalid;
 
-    // This will be a pointer type because test null IR instruction operates on a pointer to a thing.
-    TypeTableEntry *ptr_type = value->type_entry;
-    assert(ptr_type->id == TypeTableEntryIdPointer);
-
-    TypeTableEntry *type_entry = ptr_type->data.pointer.child_type;
-    if (type_entry->id != TypeTableEntryIdMaybe) {
-        add_node_error(ira->codegen, test_null_instruction->base.source_node,
-                buf_sprintf("expected nullable type, found '%s'", buf_ptr(&type_entry->name)));
-        return ira->codegen->builtin_types.entry_invalid;
-    }
+    TypeTableEntry *type_entry = value->type_entry;
 
-    if (value->static_value.special != ConstValSpecialRuntime) {
-        ConstExprValue *maybe_val = value->static_value.data.x_ptr.base_ptr;
-        assert(value->static_value.data.x_ptr.index == SIZE_MAX);
+    if (type_entry->id == TypeTableEntryIdMaybe) {
+        if (instr_is_comptime(value)) {
+            ConstExprValue *maybe_val = ir_resolve_const(ira, value, UndefBad);
+            if (!maybe_val)
+                return ira->codegen->builtin_types.entry_invalid;
 
-        if (maybe_val->special != ConstValSpecialRuntime) {
-            bool depends_on_compile_var = maybe_val->depends_on_compile_var;
-            ConstExprValue *out_val = ir_build_const_from(ira, &test_null_instruction->base,
-                    depends_on_compile_var);
-            out_val->data.x_bool = (maybe_val->data.x_maybe == nullptr);
+            ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base,
+                maybe_val->depends_on_compile_var);
+            out_val->data.x_bool = (maybe_val->data.x_maybe != nullptr);
             return ira->codegen->builtin_types.entry_bool;
         }
-    }
 
-    ir_build_test_null_from(&ira->new_irb, &test_null_instruction->base, value);
-    return ira->codegen->builtin_types.entry_bool;
+        ir_build_test_null_from(&ira->new_irb, &instruction->base, value);
+        return ira->codegen->builtin_types.entry_bool;
+    } else if (type_entry->id == TypeTableEntryIdNullLit) {
+        ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base, false);
+        out_val->data.x_bool = false;
+        return ira->codegen->builtin_types.entry_bool;
+    } else {
+        ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base, false);
+        out_val->data.x_bool = true;
+        return ira->codegen->builtin_types.entry_bool;
+    }
 }
 
 static TypeTableEntry *ir_analyze_instruction_unwrap_maybe(IrAnalyze *ira,
@@ -9689,8 +9706,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_compile_var(ira, (IrInstructionCompileVar *)instruction);
         case IrInstructionIdSizeOf:
             return ir_analyze_instruction_size_of(ira, (IrInstructionSizeOf *)instruction);
-        case IrInstructionIdTestNull:
-            return ir_analyze_instruction_test_null(ira, (IrInstructionTestNull *)instruction);
+        case IrInstructionIdTestNonNull:
+            return ir_analyze_instruction_test_non_null(ira, (IrInstructionTestNonNull *)instruction);
         case IrInstructionIdUnwrapMaybe:
             return ir_analyze_instruction_unwrap_maybe(ira, (IrInstructionUnwrapMaybe *)instruction);
         case IrInstructionIdClz:
@@ -9908,7 +9925,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdSliceType:
         case IrInstructionIdCompileVar:
         case IrInstructionIdSizeOf:
-        case IrInstructionIdTestNull:
+        case IrInstructionIdTestNonNull:
         case IrInstructionIdUnwrapMaybe:
         case IrInstructionIdClz:
         case IrInstructionIdCtz:
src/ir_print.cpp
@@ -579,7 +579,7 @@ static void ir_print_size_of(IrPrint *irp, IrInstructionSizeOf *instruction) {
     fprintf(irp->f, ")");
 }
 
-static void ir_print_test_null(IrPrint *irp, IrInstructionTestNull *instruction) {
+static void ir_print_test_null(IrPrint *irp, IrInstructionTestNonNull *instruction) {
     fprintf(irp->f, "*");
     ir_print_other_instruction(irp, instruction->value);
     fprintf(irp->f, " != null");
@@ -1012,8 +1012,8 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdSizeOf:
             ir_print_size_of(irp, (IrInstructionSizeOf *)instruction);
             break;
-        case IrInstructionIdTestNull:
-            ir_print_test_null(irp, (IrInstructionTestNull *)instruction);
+        case IrInstructionIdTestNonNull:
+            ir_print_test_null(irp, (IrInstructionTestNonNull *)instruction);
             break;
         case IrInstructionIdUnwrapMaybe:
             ir_print_unwrap_maybe(irp, (IrInstructionUnwrapMaybe *)instruction);
test/cases3/defer.zig
@@ -3,7 +3,7 @@ var index: usize = undefined;
 
 error FalseNotAllowed;
 
-fn runSomeDefers(x: bool) -> %bool {
+fn runSomeErrorDefers(x: bool) -> %bool {
     index = 0;
     defer {result[index] = 'a'; index += 1;};
     %defer {result[index] = 'b'; index += 1;};
@@ -11,14 +11,22 @@ fn runSomeDefers(x: bool) -> %bool {
     return if (x) x else error.FalseNotAllowed;
 }
 
+fn runSomeMaybeDefers(x: bool) -> ?bool {
+    index = 0;
+    defer {result[index] = 'a'; index += 1;};
+    ?defer {result[index] = 'b'; index += 1;};
+    defer {result[index] = 'c'; index += 1;};
+    return if (x) x else null;
+}
+
 fn mixingNormalAndErrorDefers() {
     @setFnTest(this);
 
-    assert(%%runSomeDefers(true));
+    assert(%%runSomeErrorDefers(true));
     assert(result[0] == 'c');
     assert(result[1] == 'a');
 
-    const ok = runSomeDefers(false) %% |err| {
+    const ok = runSomeErrorDefers(false) %% |err| {
         assert(err == error.FalseNotAllowed);
         true
     };
@@ -28,6 +36,20 @@ fn mixingNormalAndErrorDefers() {
     assert(result[2] == 'a');
 }
 
+fn mixingNormalAndMaybeDefers() {
+    @setFnTest(this);
+
+    assert(??runSomeMaybeDefers(true));
+    assert(result[0] == 'c');
+    assert(result[1] == 'a');
+
+    const ok = runSomeMaybeDefers(false) ?? true;
+    assert(ok);
+    assert(result[0] == 'c');
+    assert(result[1] == 'b');
+    assert(result[2] == 'a');
+}
+
 // TODO const assert = @import("std").debug.assert;
 fn assert(ok: bool) {
     if (!ok)