Commit 04ee3b01a1

Andrew Kelley <andrew@ziglang.org>
2020-02-09 23:19:28
fix defer interfering with return value spill
1 parent 5b10d9f
Changed files (4)
src/analyze.cpp
@@ -6367,7 +6367,9 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
             IrInstGen *instruction = block->instruction_list.at(instr_i);
             if (instruction->id == IrInstGenIdAwait ||
                 instruction->id == IrInstGenIdVarPtr ||
-                instruction->id == IrInstGenIdAlloca)
+                instruction->id == IrInstGenIdAlloca ||
+                instruction->id == IrInstGenIdSpillBegin ||
+                instruction->id == IrInstGenIdSpillEnd)
             {
                 // This instruction does its own spilling specially, or otherwise doesn't need it.
                 continue;
src/codegen.cpp
@@ -2561,7 +2561,12 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
             LLVMBuildRet(g->builder, by_val_value);
         }
     } else if (instruction->operand == nullptr) {
-        LLVMBuildRetVoid(g->builder);
+        if (g->cur_ret_ptr == nullptr) {
+            LLVMBuildRetVoid(g->builder);
+        } else {
+            LLVMValueRef by_val_value = gen_load_untyped(g, g->cur_ret_ptr, 0, false, "");
+            LLVMBuildRet(g->builder, by_val_value);
+        }
     } else {
         LLVMValueRef value = ir_llvm_value(g, instruction->operand);
         LLVMBuildRet(g->builder, value);
@@ -5715,18 +5720,24 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex
     bool want_safety = instruction->safety_check_on && ir_want_runtime_safety(g, &instruction->base) &&
         g->errors_by_index.length > 1;
 
-    bool value_has_bits;
-    if ((err = type_has_bits2(g, instruction->base.value->type, &value_has_bits)))
-        codegen_report_errors_and_exit(g);
-
-    if (!want_safety && !value_has_bits)
-        return nullptr;
-
     ZigType *ptr_type = instruction->value->value->type;
     assert(ptr_type->id == ZigTypeIdPointer);
     ZigType *err_union_type = ptr_type->data.pointer.child_type;
     ZigType *payload_type = err_union_type->data.error_union.payload_type;
     LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->value);
+
+    LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, g->err_tag_type));
+    bool value_has_bits;
+    if ((err = type_has_bits2(g, instruction->base.value->type, &value_has_bits)))
+        codegen_report_errors_and_exit(g);
+    if (!want_safety && !value_has_bits) {
+        if (instruction->initializing) {
+            gen_store_untyped(g, zero, err_union_ptr, 0, false);
+        }
+        return nullptr;
+    }
+
+
     LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type);
 
     if (!type_has_bits(err_union_type->data.error_union.err_set_type)) {
@@ -5741,7 +5752,6 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex
         } else {
             err_val = err_union_handle;
         }
-        LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, g->err_tag_type));
         LLVMValueRef cond_val = LLVMBuildICmp(g->builder, LLVMIntEQ, err_val, zero, "");
         LLVMBasicBlockRef err_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapErrError");
         LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnwrapErrOk");
@@ -5761,6 +5771,9 @@ static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutableGen *ex
         }
         return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_payload_index, "");
     } else {
+        if (instruction->initializing) {
+            gen_store_untyped(g, zero, err_union_ptr, 0, false);
+        }
         return nullptr;
     }
 }
src/ir.cpp
@@ -5252,6 +5252,7 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                         return irb->codegen->invalid_inst_src;
                 } else {
                     return_value = ir_build_const_void(irb, scope, node);
+                    ir_build_end_expr(irb, scope, node, return_value, &result_loc_ret->base);
                 }
 
                 ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value, result_loc_ret));
@@ -5262,7 +5263,7 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 if (!have_err_defers && !irb->codegen->have_err_ret_tracing) {
                     // only generate unconditional defers
                     ir_gen_defers_for_block(irb, scope, outer_scope, false);
-                    IrInstSrc *result = ir_build_return_src(irb, scope, node, return_value);
+                    IrInstSrc *result = ir_build_return_src(irb, scope, node, nullptr);
                     result_loc_ret->base.source_instruction = result;
                     return result;
                 }
@@ -5271,10 +5272,6 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 IrBasicBlockSrc *err_block = ir_create_basic_block(irb, scope, "ErrRetErr");
                 IrBasicBlockSrc *ok_block = ir_create_basic_block(irb, scope, "ErrRetOk");
 
-                if (!have_err_defers) {
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
-                }
-
                 IrInstSrc *is_err = ir_build_test_err_src(irb, scope, node, return_value, false, true);
 
                 IrInstSrc *is_comptime;
@@ -5288,22 +5285,18 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 IrBasicBlockSrc *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt");
 
                 ir_set_cursor_at_end_and_append_block(irb, err_block);
-                if (have_err_defers) {
-                    ir_gen_defers_for_block(irb, scope, outer_scope, true);
-                }
+                ir_gen_defers_for_block(irb, scope, outer_scope, true);
                 if (irb->codegen->have_err_ret_tracing && !should_inline) {
                     ir_build_save_err_ret_addr_src(irb, scope, node);
                 }
                 ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                 ir_set_cursor_at_end_and_append_block(irb, ok_block);
-                if (have_err_defers) {
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
-                }
+                ir_gen_defers_for_block(irb, scope, outer_scope, false);
                 ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                 ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
-                IrInstSrc *result = ir_build_return_src(irb, scope, node, return_value);
+                IrInstSrc *result = ir_build_return_src(irb, scope, node, nullptr);
                 result_loc_ret->base.source_instruction = result;
                 return result;
             }
@@ -9622,7 +9615,10 @@ static IrInstSrc *ir_gen_catch(IrBuilderSrc *irb, Scope *parent_scope, AstNode *
     }
 
 
-    IrInstSrc *err_union_ptr = ir_gen_node_extra(irb, op1_node, parent_scope, LValPtr, nullptr);
+    ScopeExpr *spill_scope = create_expr_scope(irb->codegen, op1_node, parent_scope);
+    spill_scope->spill_harder = true;
+
+    IrInstSrc *err_union_ptr = ir_gen_node_extra(irb, op1_node, &spill_scope->base, LValPtr, nullptr);
     if (err_union_ptr == irb->codegen->invalid_inst_src)
         return irb->codegen->invalid_inst_src;
 
@@ -9644,7 +9640,7 @@ static IrInstSrc *ir_gen_catch(IrBuilderSrc *irb, Scope *parent_scope, AstNode *
             is_comptime);
 
     ir_set_cursor_at_end_and_append_block(irb, err_block);
-    Scope *subexpr_scope = create_runtime_scope(irb->codegen, node, parent_scope, is_comptime);
+    Scope *subexpr_scope = create_runtime_scope(irb->codegen, node, &spill_scope->base, is_comptime);
     Scope *err_scope;
     if (var_node) {
         assert(var_node->type == NodeTypeSymbol);
@@ -15497,6 +15493,12 @@ static IrInstGen *ir_analyze_instruction_add_implicit_return_type(IrAnalyze *ira
 }
 
 static IrInstGen *ir_analyze_instruction_return(IrAnalyze *ira, IrInstSrcReturn *instruction) {
+    if (instruction->operand == nullptr) {
+        // result location mechanism took care of it.
+        IrInstGen *result = ir_build_return_gen(ira, &instruction->base.base, nullptr);
+        return ir_finish_anal(ira, result);
+    }
+
     IrInstGen *operand = instruction->operand->child;
     if (type_is_invalid(operand->value->type))
         return ir_unreach_error(ira);
@@ -29551,8 +29553,13 @@ static IrInstGen *ir_analyze_instruction_spill_begin(IrAnalyze *ira, IrInstSrcSp
     if (!type_has_bits(operand->value->type))
         return ir_const_void(ira, &instruction->base.base);
 
-    ir_assert(instruction->spill_id == SpillIdRetErrCode, &instruction->base.base);
-    ira->new_irb.exec->need_err_code_spill = true;
+    switch (instruction->spill_id) {
+        case SpillIdInvalid:
+            zig_unreachable();
+        case SpillIdRetErrCode:
+            ira->new_irb.exec->need_err_code_spill = true;
+            break;
+    }
 
     return ir_build_spill_begin_gen(ira, &instruction->base.base, operand, instruction->spill_id);
 }
@@ -29562,8 +29569,12 @@ static IrInstGen *ir_analyze_instruction_spill_end(IrAnalyze *ira, IrInstSrcSpil
     if (type_is_invalid(operand->value->type))
         return ira->codegen->invalid_inst_gen;
 
-    if (ir_should_inline(ira->old_irb.exec, instruction->base.base.scope) || !type_has_bits(operand->value->type))
+    if (ir_should_inline(ira->old_irb.exec, instruction->base.base.scope) ||
+        !type_has_bits(operand->value->type) ||
+        instr_is_comptime(operand))
+    {
         return operand;
+    }
 
     ir_assert(instruction->begin->base.child->id == IrInstGenIdSpillBegin, &instruction->base.base);
     IrInstGenSpillBegin *begin = reinterpret_cast<IrInstGenSpillBegin *>(instruction->begin->base.child);
test/stage1/behavior/async_fn.zig
@@ -2,6 +2,7 @@ const std = @import("std");
 const builtin = @import("builtin");
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
+const expectError = std.testing.expectError;
 
 var global_x: i32 = 1;
 
@@ -1440,3 +1441,43 @@ test "properly spill optional payload capture value" {
     resume S.global_frame;
     expect(S.global_int == 1237);
 }
+
+test "handle defer interfering with return value spill" {
+    const S = struct {
+        var global_frame1: anyframe = undefined;
+        var global_frame2: anyframe = undefined;
+        var finished = false;
+        var baz_happened = false;
+
+        fn doTheTest() void {
+            _ = async testFoo();
+            resume global_frame1;
+            resume global_frame2;
+            expect(baz_happened);
+            expect(finished);
+        }
+
+        fn testFoo() void {
+            expectError(error.Bad, foo());
+            finished = true;
+        }
+
+        fn foo() anyerror!void {
+            defer baz();
+            return bar() catch |err| return err;
+        }
+
+        fn bar() anyerror!void {
+            global_frame1 = @frame();
+            suspend;
+            return error.Bad;
+        }
+
+        fn baz() void {
+            global_frame2 = @frame();
+            suspend;
+            baz_happened = true;
+        }
+    };
+    S.doTheTest();
+}