Commit 2e7f53f1f0

Andrew Kelley <andrew@ziglang.org>
2019-08-09 23:34:06
fix cancel inside an errdefer
1 parent 614cab5
src/all_types.hpp
@@ -1725,6 +1725,7 @@ struct CodeGen {
     LLVMValueRef cur_async_resume_index_ptr;
     LLVMValueRef cur_async_awaiter_ptr;
     LLVMValueRef cur_async_prev_val;
+    LLVMValueRef cur_async_prev_val_field_ptr;
     LLVMBasicBlockRef cur_preamble_llvm_block;
     size_t cur_resume_block_count;
     LLVMValueRef cur_err_ret_trace_val_arg;
@@ -1886,6 +1887,7 @@ struct CodeGen {
     bool system_linker_hack;
     bool reported_bad_link_libc_error;
     bool is_dynamic; // shared library rather than static library. dynamic musl rather than static musl.
+    bool cur_is_after_return;
 
     //////////////////////////// Participates in Input Parameter Cache Hash
     /////// Note: there is a separate cache hash for builtin.zig, when adding fields,
@@ -3639,8 +3641,6 @@ struct IrInstructionCoroResume {
 
 struct IrInstructionTestCancelRequested {
     IrInstruction base;
-
-    bool use_return_begin_prev_value;
 };
 
 enum ResultLocId {
@@ -3730,7 +3730,8 @@ static const size_t err_union_payload_index = 1;
 static const size_t coro_fn_ptr_index = 0;
 static const size_t coro_resume_index = 1;
 static const size_t coro_awaiter_index = 2;
-static const size_t coro_ret_start = 3;
+static const size_t coro_prev_val_index = 3;
+static const size_t coro_ret_start = 4;
 
 // TODO call graph analysis to find out what this number needs to be for every function
 // MUST BE A POWER OF TWO.
src/analyze.cpp
@@ -5246,6 +5246,9 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
     field_names.append("@awaiter");
     field_types.append(g->builtin_types.entry_usize);
 
+    field_names.append("@prev_val");
+    field_types.append(g->builtin_types.entry_usize);
+
     FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
     ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
     field_names.append("@result_ptr_callee");
@@ -7592,6 +7595,7 @@ static void resolve_llvm_types_any_frame(CodeGen *g, ZigType *any_frame_type, Re
     field_types.append(ptr_fn_llvm_type); // fn_ptr
     field_types.append(usize_type_ref); // resume_index
     field_types.append(usize_type_ref); // awaiter
+    field_types.append(usize_type_ref); // prev_val
 
     bool have_result_type = result_type != nullptr && type_has_bits(result_type);
     if (have_result_type) {
src/codegen.cpp
@@ -2226,7 +2226,18 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar
     return ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
 }
 
+static LLVMValueRef get_cur_async_prev_val(CodeGen *g) {
+    if (g->cur_async_prev_val != nullptr) {
+        return g->cur_async_prev_val;
+    }
+    g->cur_async_prev_val = LLVMBuildLoad(g->builder, g->cur_async_prev_val_field_ptr, "");
+    return g->cur_async_prev_val;
+}
+
 static LLVMBasicBlockRef gen_suspend_begin(CodeGen *g, const char *name_hint) {
+    // This becomes invalid when a suspend happens.
+    g->cur_async_prev_val = nullptr;
+
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
     LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, name_hint);
     size_t new_block_index = g->cur_resume_block_count;
@@ -2319,6 +2330,9 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
     LLVMBasicBlockRef incoming_blocks[] = { after_resume_block, switch_bb };
     LLVMAddIncoming(g->cur_async_prev_val, incoming_values, incoming_blocks, 2);
 
+    g->cur_is_after_return = true;
+    LLVMBuildStore(g->builder, g->cur_async_prev_val, g->cur_async_prev_val_field_ptr);
+
     if (!ret_type_has_bits) {
         return nullptr;
     }
@@ -2366,7 +2380,7 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns
         ZigType *any_frame_type = get_any_frame_type(g, ret_type);
         LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
         LLVMValueRef mask_val = LLVMConstNot(one);
-        LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, g->cur_async_prev_val, mask_val, "");
+        LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, get_cur_async_prev_val(g), mask_val, "");
         LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, masked_prev_val,
                 get_llvm_type(g, any_frame_type), "");
         LLVMValueRef call_inst = gen_resume(g, nullptr, their_frame_ptr, ResumeIdReturn, nullptr);
@@ -5590,8 +5604,8 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex
 {
     if (!fn_is_async(g->cur_fn))
         return LLVMConstInt(LLVMInt1Type(), 0, false);
-    if (instruction->use_return_begin_prev_value) {
-        return LLVMBuildTrunc(g->builder, g->cur_async_prev_val, LLVMInt1Type(), "");
+    if (g->cur_is_after_return) {
+        return LLVMBuildTrunc(g->builder, get_cur_async_prev_val(g), LLVMInt1Type(), "");
     } else {
         zig_panic("TODO");
     }
@@ -7063,6 +7077,7 @@ static void do_code_gen(CodeGen *g) {
         }
 
         if (is_async) {
+            g->cur_is_after_return = false;
             g->cur_resume_block_count = 0;
 
             LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
@@ -7099,6 +7114,8 @@ static void do_code_gen(CodeGen *g) {
                 g->cur_err_ret_trace_val_stack = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
                         trace_field_index_stack, "");
             }
+            g->cur_async_prev_val_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,
+                    coro_prev_val_index, "");
 
             LLVMValueRef resume_index = LLVMBuildLoad(g->builder, resume_index_ptr, "");
             LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, resume_index, bad_resume_block, 4);
src/ir.cpp
@@ -3325,12 +3325,9 @@ static IrInstruction *ir_build_coro_resume(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scope, AstNode *source_node,
-        bool use_return_begin_prev_value)
-{
+static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     IrInstructionTestCancelRequested *instruction = ir_build_instruction<IrInstructionTestCancelRequested>(irb, scope, source_node);
     instruction->base.value.type = irb->codegen->builtin_types.entry_bool;
-    instruction->use_return_begin_prev_value = use_return_begin_prev_value;
 
     return &instruction->base;
 }
@@ -3546,7 +3543,7 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
 
                 if (need_test_cancel) {
                     ir_set_cursor_at_end_and_append_block(irb, ok_block);
-                    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, scope, node, true);
+                    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, scope, node);
                     ir_mark_gen(ir_build_cond_br(irb, scope, node, is_canceled,
                                 all_defers_block, normal_defers_block, force_comptime));
                 }
@@ -3830,7 +3827,7 @@ static IrInstruction *ir_gen_block(IrBuilder *irb, Scope *parent_scope, AstNode
         ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
         return ir_mark_gen(ir_build_return(irb, child_scope, result->source_node, result));
     }
-    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, child_scope, block_node, true);
+    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, child_scope, block_node);
     IrBasicBlock *all_defers_block = ir_create_basic_block(irb, child_scope, "ErrDefers");
     IrBasicBlock *normal_defers_block = ir_create_basic_block(irb, child_scope, "Defers");
     IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, child_scope, "RetStmt");
@@ -24725,8 +24722,7 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir
     if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) {
         return ir_const_bool(ira, &instruction->base, false);
     }
-    return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
-            instruction->use_return_begin_prev_value);
+    return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node);
 }
 
 static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
src/ir_print.cpp
@@ -1551,8 +1551,7 @@ static void ir_print_await_gen(IrPrint *irp, IrInstructionAwaitGen *instruction)
 }
 
 static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancelRequested *instruction) {
-    const char *arg = instruction->use_return_begin_prev_value ? "UseReturnBeginPrevValue" : "AdditionalCheck";
-    fprintf(irp->f, "@testCancelRequested(%s)", arg);
+    fprintf(irp->f, "@testCancelRequested()");
 }
 
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
test/stage1/behavior/coroutines.zig
@@ -318,7 +318,7 @@ test "@asyncCall with return type" {
         }
     };
     var foo = Foo{ .bar = Foo.middle };
-    var bytes: [100]u8 = undefined;
+    var bytes: [150]u8 = undefined;
     var aresult: i32 = 0;
     _ = @asyncCall(&bytes, &aresult, foo.bar);
     expect(aresult == 0);
@@ -589,3 +589,27 @@ test "pass string literal to async function" {
     };
     S.doTheTest();
 }
+
+test "cancel inside an errdefer" {
+    const S = struct {
+        var frame: anyframe = undefined;
+
+        fn doTheTest() void {
+            _ = async amainWrap();
+            resume frame;
+        }
+
+        fn amainWrap() !void {
+            var foo = async func();
+            errdefer cancel foo;
+            return error.Bad;
+        }
+
+        fn func() void {
+            frame = @frame();
+            suspend;
+        }
+
+    };
+    S.doTheTest();
+}