Commit 22428a7546

Andrew Kelley <andrew@ziglang.org>
2019-08-10 21:20:08
fix try in an async function with error union and non-zero-bit payload
1 parent b9d1d45
src/all_types.hpp
@@ -74,6 +74,7 @@ struct IrExecutable {
     bool invalid;
     bool is_inline;
     bool is_generic_instantiation;
+    bool need_err_code_spill;
 };
 
 enum OutType {
@@ -1384,6 +1385,7 @@ struct ZigFn {
     size_t prealloc_backward_branch_quota;
     AstNode **param_source_nodes;
     Buf **param_names;
+    IrInstruction *err_code_spill;
 
     AstNode *fn_no_inline_set_node;
     AstNode *fn_static_eval_set_node;
@@ -2366,7 +2368,8 @@ enum IrInstructionId {
     IrInstructionIdAwaitGen,
     IrInstructionIdCoroResume,
     IrInstructionIdTestCancelRequested,
-    IrInstructionIdSpill,
+    IrInstructionIdSpillBegin,
+    IrInstructionIdSpillEnd,
 };
 
 struct IrInstruction {
@@ -3649,13 +3652,19 @@ enum SpillId {
     SpillIdRetErrCode,
 };
 
-struct IrInstructionSpill {
+struct IrInstructionSpillBegin {
     IrInstruction base;
 
     SpillId spill_id;
     IrInstruction *operand;
 };
 
+struct IrInstructionSpillEnd {
+    IrInstruction base;
+
+    IrInstructionSpillBegin *begin;
+};
+
 enum ResultLocId {
     ResultLocIdInvalid,
     ResultLocIdNone,
src/analyze.cpp
@@ -5190,6 +5190,18 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
     }
     ZigType *fn_type = get_async_fn_type(g, fn->type_entry);
 
+    if (fn->analyzed_executable.need_err_code_spill) {
+        IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
+        alloca_gen->base.id = IrInstructionIdAllocaGen;
+        alloca_gen->base.source_node = fn->proto_node;
+        alloca_gen->base.scope = fn->child_scope;
+        alloca_gen->base.value.type = get_pointer_to_type(g, g->builtin_types.entry_global_error_set, false);
+        alloca_gen->base.ref_count = 1;
+        alloca_gen->name_hint = "";
+        fn->alloca_gen_list.append(alloca_gen);
+        fn->err_code_spill = &alloca_gen->base;
+    }
+
     for (size_t i = 0; i < fn->call_list.length; i += 1) {
         IrInstructionCallGen *call = fn->call_list.at(i);
         ZigFn *callee = call->fn_entry;
src/codegen.cpp
@@ -2274,16 +2274,16 @@ static LLVMValueRef gen_maybe_atomic_op(CodeGen *g, LLVMAtomicRMWBinOp op, LLVMV
 static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
         IrInstructionReturnBegin *instruction)
 {
-    bool ret_type_has_bits = instruction->operand != nullptr &&
-        type_has_bits(instruction->operand->value.type);
-
+    ZigType *operand_type = (instruction->operand != nullptr) ? instruction->operand->value.type : nullptr;
+    bool operand_has_bits = (operand_type != nullptr) && type_has_bits(operand_type);
     if (!fn_is_async(g->cur_fn)) {
-        return ret_type_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr;
+        return operand_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr;
     }
 
+    ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
+    bool ret_type_has_bits = type_has_bits(ret_type);
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
 
-    ZigType *ret_type = ret_type_has_bits ? instruction->operand->value.type : nullptr;
     if (ret_type_has_bits && !handle_is_ptr(ret_type)) {
         // It's a scalar, so it didn't get written to the result ptr. Do that before the atomic rmw.
         LLVMBuildStore(g->builder, ir_llvm_value(g, instruction->operand), g->cur_ret_ptr);
@@ -2333,11 +2333,11 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
     g->cur_is_after_return = true;
     LLVMBuildStore(g->builder, g->cur_async_prev_val, g->cur_async_prev_val_field_ptr);
 
-    if (!ret_type_has_bits) {
+    if (!operand_has_bits) {
         return nullptr;
     }
 
-    return get_handle_value(g, g->cur_ret_ptr, ret_type, get_pointer_to_type(g, ret_type, true));
+    return get_handle_value(g, g->cur_ret_ptr, operand_type, get_pointer_to_type(g, operand_type, true));
 }
 
 static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *instruction) {
@@ -5113,18 +5113,6 @@ static LLVMValueRef ir_render_test_err(CodeGen *g, IrExecutable *executable, IrI
     return LLVMBuildICmp(g->builder, LLVMIntNE, err_val, zero, "");
 }
 
-static LLVMValueRef gen_unwrap_err_code(CodeGen *g, LLVMValueRef err_union_ptr, ZigType *ptr_type) {
-    ZigType *err_union_type = ptr_type->data.pointer.child_type;
-    ZigType *payload_type = err_union_type->data.error_union.payload_type;
-    if (!type_has_bits(payload_type)) {
-        return err_union_ptr;
-    } else {
-        // TODO assign undef to the payload
-        LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type);
-        return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, "");
-    }
-}
-
 static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable,
         IrInstructionUnwrapErrCode *instruction)
 {
@@ -5133,8 +5121,16 @@ static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executab
 
     ZigType *ptr_type = instruction->err_union_ptr->value.type;
     assert(ptr_type->id == ZigTypeIdPointer);
+    ZigType *err_union_type = ptr_type->data.pointer.child_type;
+    ZigType *payload_type = err_union_type->data.error_union.payload_type;
     LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr);
-    return gen_unwrap_err_code(g, err_union_ptr, ptr_type);
+    if (!type_has_bits(payload_type)) {
+        return err_union_ptr;
+    } else {
+        // TODO assign undef to the payload
+        LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type);
+        return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, "");
+    }
 }
 
 static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *executable,
@@ -5615,21 +5611,36 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex
     }
 }
 
-static LLVMValueRef ir_render_spill(CodeGen *g, IrExecutable *executable, IrInstructionSpill *instruction) {
+static LLVMValueRef ir_render_spill_begin(CodeGen *g, IrExecutable *executable,
+        IrInstructionSpillBegin *instruction)
+{
     if (!fn_is_async(g->cur_fn))
-        return ir_llvm_value(g, instruction->operand);
+        return nullptr;
 
     switch (instruction->spill_id) {
         case SpillIdInvalid:
             zig_unreachable();
         case SpillIdRetErrCode: {
-            LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr, "");
-            ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
-            if (ret_type->id == ZigTypeIdErrorUnion) {
-                return gen_unwrap_err_code(g, ret_ptr, get_pointer_to_type(g, ret_type, true));
-            } else {
-                zig_unreachable();
-            }
+            LLVMValueRef operand = ir_llvm_value(g, instruction->operand);
+            LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill);
+            LLVMBuildStore(g->builder, operand, ptr);
+            return nullptr;
+        }
+
+    }
+    zig_unreachable();
+}
+
+static LLVMValueRef ir_render_spill_end(CodeGen *g, IrExecutable *executable, IrInstructionSpillEnd *instruction) {
+    if (!fn_is_async(g->cur_fn))
+        return ir_llvm_value(g, instruction->begin->operand);
+
+    switch (instruction->begin->spill_id) {
+        case SpillIdInvalid:
+            zig_unreachable();
+        case SpillIdRetErrCode: {
+            LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill);
+            return LLVMBuildLoad(g->builder, ptr, "");
         }
 
     }
@@ -5891,8 +5902,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction);
         case IrInstructionIdTestCancelRequested:
             return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction);
-        case IrInstructionIdSpill:
-            return ir_render_spill(g, executable, (IrInstructionSpill *)instruction);
+        case IrInstructionIdSpillBegin:
+            return ir_render_spill_begin(g, executable, (IrInstructionSpillBegin *)instruction);
+        case IrInstructionIdSpillEnd:
+            return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction);
     }
     zig_unreachable();
 }
src/ir.cpp
@@ -1066,8 +1066,12 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelReques
     return IrInstructionIdTestCancelRequested;
 }
 
-static constexpr IrInstructionId ir_instruction_id(IrInstructionSpill *) {
-    return IrInstructionIdSpill;
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillBegin *) {
+    return IrInstructionIdSpillBegin;
+}
+
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillEnd *) {
+    return IrInstructionIdSpillEnd;
 }
 
 template<typename T>
@@ -3336,15 +3340,28 @@ static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scop
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_spill(IrBuilder *irb, Scope *scope, AstNode *source_node,
+static IrInstructionSpillBegin *ir_build_spill_begin(IrBuilder *irb, Scope *scope, AstNode *source_node,
         IrInstruction *operand, SpillId spill_id)
 {
-    IrInstructionSpill *instruction = ir_build_instruction<IrInstructionSpill>(irb, scope, source_node);
+    IrInstructionSpillBegin *instruction = ir_build_instruction<IrInstructionSpillBegin>(irb, scope, source_node);
+    instruction->base.value.special = ConstValSpecialStatic;
+    instruction->base.value.type = irb->codegen->builtin_types.entry_void;
     instruction->operand = operand;
     instruction->spill_id = spill_id;
 
     ir_ref_instruction(operand, irb->current_basic_block);
 
+    return instruction;
+}
+
+static IrInstruction *ir_build_spill_end(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstructionSpillBegin *begin)
+{
+    IrInstructionSpillEnd *instruction = ir_build_instruction<IrInstructionSpillEnd>(irb, scope, source_node);
+    instruction->begin = begin;
+
+    ir_ref_instruction(&begin->base, irb->current_basic_block);
+
     return &instruction->base;
 }
 
@@ -3602,14 +3619,15 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                 IrInstruction *err_val_ptr = ir_build_unwrap_err_code(irb, scope, node, err_union_ptr);
                 IrInstruction *err_val = ir_build_load_ptr(irb, scope, node, err_val_ptr);
                 ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, err_val));
-                err_val = ir_build_return_begin(irb, scope, node, err_val);
+                IrInstructionSpillBegin *spill_begin = ir_build_spill_begin(irb, scope, node, err_val,
+                        SpillIdRetErrCode);
+                ir_build_return_begin(irb, scope, node, err_val);
+                err_val = ir_build_spill_end(irb, scope, node, spill_begin);
+                ResultLocReturn *result_loc_ret = allocate<ResultLocReturn>(1);
+                result_loc_ret->base.id = ResultLocIdReturn;
+                ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
+                ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
                 if (!ir_gen_defers_for_block(irb, scope, outer_scope, true)) {
-                    ResultLocReturn *result_loc_ret = allocate<ResultLocReturn>(1);
-                    result_loc_ret->base.id = ResultLocIdReturn;
-                    ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
-                    err_val = ir_build_spill(irb, scope, node, err_val, SpillIdRetErrCode);
-                    ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
-
                     if (irb->codegen->have_err_ret_tracing && !should_inline) {
                         ir_build_save_err_ret_addr(irb, scope, node);
                     }
@@ -12778,8 +12796,21 @@ static IrInstruction *ir_analyze_instruction_return(IrAnalyze *ira, IrInstructio
         return ir_finish_anal(ira, result);
     }
 
+    // This cast might have been already done from IrInstructionReturnBegin but it also
+    // might not have, in the case of `try`.
+    IrInstruction *casted_operand = ir_implicit_cast(ira, operand, ira->explicit_return_type);
+    if (type_is_invalid(casted_operand->value.type)) {
+        AstNode *source_node = ira->explicit_return_type_source_node;
+        if (source_node != nullptr) {
+            ErrorMsg *msg = ira->codegen->errors.last();
+            add_error_note(ira->codegen, msg, source_node,
+                buf_sprintf("return type declared here"));
+        }
+        return ir_unreach_error(ira);
+    }
+
     IrInstruction *result = ir_build_return(&ira->new_irb, instruction->base.scope,
-            instruction->base.source_node, operand);
+            instruction->base.source_node, casted_operand);
     result->value.type = ira->codegen->builtin_types.entry_unreachable;
     return ir_finish_anal(ira, result);
 }
@@ -24742,15 +24773,38 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir
     return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node);
 }
 
-static IrInstruction *ir_analyze_instruction_spill(IrAnalyze *ira, IrInstructionSpill *instruction) {
+static IrInstruction *ir_analyze_instruction_spill_begin(IrAnalyze *ira, IrInstructionSpillBegin *instruction) {
+    if (ir_should_inline(ira->new_irb.exec, instruction->base.scope))
+        return ir_const_void(ira, &instruction->base);
+
     IrInstruction *operand = instruction->operand->child;
     if (type_is_invalid(operand->value.type))
         return ira->codegen->invalid_instruction;
-    if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) {
+
+    if (!type_has_bits(operand->value.type))
+        return ir_const_void(ira, &instruction->base);
+
+    ir_assert(instruction->spill_id == SpillIdRetErrCode, &instruction->base);
+    ira->new_irb.exec->need_err_code_spill = true;
+
+    IrInstructionSpillBegin *result = ir_build_spill_begin(&ira->new_irb, instruction->base.scope,
+            instruction->base.source_node, operand, instruction->spill_id);
+    return &result->base;
+}
+
+static IrInstruction *ir_analyze_instruction_spill_end(IrAnalyze *ira, IrInstructionSpillEnd *instruction) {
+    IrInstruction *operand = instruction->begin->operand->child;
+    if (type_is_invalid(operand->value.type))
+        return ira->codegen->invalid_instruction;
+
+    if (ir_should_inline(ira->new_irb.exec, instruction->base.scope) || !type_has_bits(operand->value.type))
         return operand;
-    }
-    IrInstruction *result = ir_build_spill(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
-            operand, instruction->spill_id);
+
+    ir_assert(instruction->begin->base.child->id == IrInstructionIdSpillBegin, &instruction->base);
+    IrInstructionSpillBegin *begin = reinterpret_cast<IrInstructionSpillBegin *>(instruction->begin->base.child);
+
+    IrInstruction *result = ir_build_spill_end(&ira->new_irb, instruction->base.scope,
+            instruction->base.source_node, begin);
     result->value.type = operand->value.type;
     return result;
 }
@@ -25054,8 +25108,10 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
             return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction);
         case IrInstructionIdTestCancelRequested:
             return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction);
-        case IrInstructionIdSpill:
-            return ir_analyze_instruction_spill(ira, (IrInstructionSpill *)instruction);
+        case IrInstructionIdSpillBegin:
+            return ir_analyze_instruction_spill_begin(ira, (IrInstructionSpillBegin *)instruction);
+        case IrInstructionIdSpillEnd:
+            return ir_analyze_instruction_spill_end(ira, (IrInstructionSpillEnd *)instruction);
     }
     zig_unreachable();
 }
@@ -25193,6 +25249,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdCoroResume:
         case IrInstructionIdAwaitSrc:
         case IrInstructionIdAwaitGen:
+        case IrInstructionIdSpillBegin:
             return true;
 
         case IrInstructionIdPhi:
@@ -25291,7 +25348,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdAllocaSrc:
         case IrInstructionIdAllocaGen:
         case IrInstructionIdTestCancelRequested:
-        case IrInstructionIdSpill:
+        case IrInstructionIdSpillEnd:
             return false;
 
         case IrInstructionIdAsm:
src/ir_print.cpp
@@ -1554,12 +1554,18 @@ static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancel
     fprintf(irp->f, "@testCancelRequested()");
 }
 
-static void ir_print_spill(IrPrint *irp, IrInstructionSpill *instruction) {
-    fprintf(irp->f, "@spill(");
+static void ir_print_spill_begin(IrPrint *irp, IrInstructionSpillBegin *instruction) {
+    fprintf(irp->f, "@spillBegin(");
     ir_print_other_instruction(irp, instruction->operand);
     fprintf(irp->f, ")");
 }
 
+static void ir_print_spill_end(IrPrint *irp, IrInstructionSpillEnd *instruction) {
+    fprintf(irp->f, "@spillEnd(");
+    ir_print_other_instruction(irp, &instruction->begin->base);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -2045,8 +2051,11 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdTestCancelRequested:
             ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction);
             break;
-        case IrInstructionIdSpill:
-            ir_print_spill(irp, (IrInstructionSpill *)instruction);
+        case IrInstructionIdSpillBegin:
+            ir_print_spill_begin(irp, (IrInstructionSpillBegin *)instruction);
+            break;
+        case IrInstructionIdSpillEnd:
+            ir_print_spill_end(irp, (IrInstructionSpillEnd *)instruction);
             break;
     }
     fprintf(irp->f, "\n");
test/stage1/behavior/coroutines.zig
@@ -642,3 +642,33 @@ test "combining try with errdefer cancel" {
     };
     S.doTheTest();
 }
+
+test "try in an async function with error union and non-zero-bit payload" {
+    const S = struct {
+        var frame: anyframe = undefined;
+        var ok = false;
+
+        fn doTheTest() void {
+            _ = async amain();
+            resume frame;
+            expect(ok);
+        }
+
+        fn amain() void {
+            std.testing.expectError(error.Bad, theProblem());
+            ok = true;
+        }
+
+        fn theProblem() ![]u8 {
+            frame = @frame();
+            suspend;
+            const result = try other();
+            return result;
+        }
+
+        fn other() ![]u8 {
+            return error.Bad;
+        }
+    };
+    S.doTheTest();
+}