Commit 34bfdf193a

Andrew Kelley <andrew@ziglang.org>
2019-08-08 17:37:49
cancel, defer, errdefer all working as intended now
1 parent e11cafb
src/all_types.hpp
@@ -2363,6 +2363,7 @@ enum IrInstructionId {
     IrInstructionIdAwaitSrc,
     IrInstructionIdAwaitGen,
     IrInstructionIdCoroResume,
+    IrInstructionIdTestCancelRequested,
 };
 
 struct IrInstruction {
@@ -3636,6 +3637,12 @@ struct IrInstructionCoroResume {
     IrInstruction *frame;
 };
 
+struct IrInstructionTestCancelRequested {
+    IrInstruction base;
+
+    bool use_return_begin_prev_value;
+};
+
 enum ResultLocId {
     ResultLocIdInvalid,
     ResultLocIdNone,
src/codegen.cpp
@@ -5557,6 +5557,18 @@ static LLVMValueRef ir_render_frame_size(CodeGen *g, IrExecutable *executable,
     return gen_frame_size(g, fn_val);
 }
 
+static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *executable,
+        IrInstructionTestCancelRequested *instruction)
+{
+    if (!fn_is_async(g->cur_fn))
+        return LLVMConstInt(LLVMInt1Type(), 0, false);
+    if (instruction->use_return_begin_prev_value) {
+        return LLVMBuildTrunc(g->builder, g->cur_async_prev_val, LLVMInt1Type(), "");
+    } else {
+        zig_panic("TODO");
+    }
+}
+
 static void set_debug_location(CodeGen *g, IrInstruction *instruction) {
     AstNode *source_node = instruction->source_node;
     Scope *scope = instruction->scope;
@@ -5810,6 +5822,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_frame_size(g, executable, (IrInstructionFrameSizeGen *)instruction);
         case IrInstructionIdAwaitGen:
             return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction);
+        case IrInstructionIdTestCancelRequested:
+            return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction);
     }
     zig_unreachable();
 }
src/ir.cpp
@@ -26,6 +26,7 @@ struct IrBuilder {
     CodeGen *codegen;
     IrExecutable *exec;
     IrBasicBlock *current_basic_block;
+    AstNode *main_block_node;
 };
 
 struct IrAnalyze {
@@ -1061,6 +1062,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionCoroResume *) {
     return IrInstructionIdCoroResume;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelRequested *) {
+    return IrInstructionIdTestCancelRequested;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -3320,6 +3325,16 @@ static IrInstruction *ir_build_coro_resume(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        bool use_return_begin_prev_value)
+{
+    IrInstructionTestCancelRequested *instruction = ir_build_instruction<IrInstructionTestCancelRequested>(irb, scope, source_node);
+    instruction->base.value.type = irb->codegen->builtin_types.entry_bool;
+    instruction->use_return_begin_prev_value = use_return_begin_prev_value;
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -3494,45 +3509,62 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                 size_t defer_counts[2];
                 ir_count_defers(irb, scope, outer_scope, defer_counts);
                 bool have_err_defers = defer_counts[ReturnKindError] > 0;
-                if (have_err_defers || irb->codegen->have_err_ret_tracing) {
-                    IrBasicBlock *err_block = ir_create_basic_block(irb, scope, "ErrRetErr");
-                    IrBasicBlock *ok_block = ir_create_basic_block(irb, scope, "ErrRetOk");
+                if (!have_err_defers && !irb->codegen->have_err_ret_tracing) {
+                    // only generate unconditional defers
+                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                    IrInstruction *result = ir_build_return(irb, scope, node, return_value);
+                    result_loc_ret->base.source_instruction = result;
+                    return result;
+                }
+                bool should_inline = ir_should_inline(irb->exec, scope);
+                bool need_test_cancel = !should_inline && have_err_defers;
 
-                    IrInstruction *is_err = ir_build_test_err_src(irb, scope, node, return_value, false, true);
+                IrBasicBlock *err_block = ir_create_basic_block(irb, scope, "ErrRetErr");
+                IrBasicBlock *normal_defers_block = ir_create_basic_block(irb, scope, "Defers");
+                IrBasicBlock *ok_block = need_test_cancel ?
+                    ir_create_basic_block(irb, scope, "ErrRetOk") : normal_defers_block;
+                IrBasicBlock *all_defers_block = have_err_defers ? ir_create_basic_block(irb, scope, "ErrDefers") : normal_defers_block;
 
-                    bool should_inline = ir_should_inline(irb->exec, scope);
-                    IrInstruction *is_comptime;
-                    if (should_inline) {
-                        is_comptime = ir_build_const_bool(irb, scope, node, true);
-                    } else {
-                        is_comptime = ir_build_test_comptime(irb, scope, node, is_err);
-                    }
+                IrInstruction *is_err = ir_build_test_err_src(irb, scope, node, return_value, false, true);
 
-                    ir_mark_gen(ir_build_cond_br(irb, scope, node, is_err, err_block, ok_block, is_comptime));
-                    IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt");
+                IrInstruction *force_comptime = ir_build_const_bool(irb, scope, node, should_inline);
+                IrInstruction *err_is_comptime;
+                if (should_inline) {
+                    err_is_comptime = force_comptime;
+                } else {
+                    err_is_comptime = ir_build_test_comptime(irb, scope, node, is_err);
+                }
 
-                    ir_set_cursor_at_end_and_append_block(irb, err_block);
-                    if (irb->codegen->have_err_ret_tracing && !should_inline) {
-                        ir_build_save_err_ret_addr(irb, scope, node);
-                    }
-                    ir_gen_defers_for_block(irb, scope, outer_scope, true);
-                    ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
+                ir_mark_gen(ir_build_cond_br(irb, scope, node, is_err, err_block, ok_block, err_is_comptime));
+                IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt");
 
+                ir_set_cursor_at_end_and_append_block(irb, err_block);
+                if (irb->codegen->have_err_ret_tracing && !should_inline) {
+                    ir_build_save_err_ret_addr(irb, scope, node);
+                }
+                ir_build_br(irb, scope, node, all_defers_block, err_is_comptime);
+
+                if (need_test_cancel) {
                     ir_set_cursor_at_end_and_append_block(irb, ok_block);
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
-                    ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
+                    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, scope, node, true);
+                    ir_mark_gen(ir_build_cond_br(irb, scope, node, is_canceled,
+                                all_defers_block, normal_defers_block, force_comptime));
+                }
 
-                    ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
-                    IrInstruction *result = ir_build_return(irb, scope, node, return_value);
-                    result_loc_ret->base.source_instruction = result;
-                    return result;
-                } else {
-                    // generate unconditional defers
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
-                    IrInstruction *result = ir_build_return(irb, scope, node, return_value);
-                    result_loc_ret->base.source_instruction = result;
-                    return result;
+                if (all_defers_block != normal_defers_block) {
+                    ir_set_cursor_at_end_and_append_block(irb, all_defers_block);
+                    ir_gen_defers_for_block(irb, scope, outer_scope, true);
+                    ir_build_br(irb, scope, node, ret_stmt_block, force_comptime);
                 }
+
+                ir_set_cursor_at_end_and_append_block(irb, normal_defers_block);
+                ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                ir_build_br(irb, scope, node, ret_stmt_block, force_comptime);
+
+                ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
+                IrInstruction *result = ir_build_return(irb, scope, node, return_value);
+                result_loc_ret->base.source_instruction = result;
+                return result;
             }
         case ReturnKindError:
             {
@@ -3765,18 +3797,59 @@ static IrInstruction *ir_gen_block(IrBuilder *irb, Scope *parent_scope, AstNode
         incoming_values.append(else_expr_result);
     }
 
-    if (block_node->data.block.name != nullptr) {
+    bool is_return_from_fn = block_node == irb->main_block_node;
+    if (!is_return_from_fn) {
         ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
+    }
+
+    IrInstruction *result;
+    if (block_node->data.block.name != nullptr) {
         ir_mark_gen(ir_build_br(irb, parent_scope, block_node, scope_block->end_block, scope_block->is_comptime));
         ir_set_cursor_at_end_and_append_block(irb, scope_block->end_block);
         IrInstruction *phi = ir_build_phi(irb, parent_scope, block_node, incoming_blocks.length,
                 incoming_blocks.items, incoming_values.items, scope_block->peer_parent);
-        return ir_expr_wrap(irb, parent_scope, phi, result_loc);
+        result = ir_expr_wrap(irb, parent_scope, phi, result_loc);
     } else {
-        ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
         IrInstruction *void_inst = ir_mark_gen(ir_build_const_void(irb, child_scope, block_node));
-        return ir_lval_wrap(irb, parent_scope, void_inst, lval, result_loc);
+        result = ir_lval_wrap(irb, parent_scope, void_inst, lval, result_loc);
     }
+    if (!is_return_from_fn)
+        return result;
+
+    // no need for save_err_ret_addr because this cannot return error
+    // but if it is a canceled async function we do need to run the errdefers
+
+    ir_mark_gen(ir_build_add_implicit_return_type(irb, child_scope, block_node, result));
+    result = ir_mark_gen(ir_build_return_begin(irb, child_scope, block_node, result));
+
+    size_t defer_counts[2];
+    ir_count_defers(irb, child_scope, outer_block_scope, defer_counts);
+    bool have_err_defers = defer_counts[ReturnKindError] > 0;
+    if (!have_err_defers) {
+        // only generate unconditional defers
+        ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
+        return ir_mark_gen(ir_build_return(irb, child_scope, result->source_node, result));
+    }
+    IrInstruction *is_canceled = ir_build_test_cancel_requested(irb, child_scope, block_node, true);
+    IrBasicBlock *all_defers_block = ir_create_basic_block(irb, child_scope, "ErrDefers");
+    IrBasicBlock *normal_defers_block = ir_create_basic_block(irb, child_scope, "Defers");
+    IrBasicBlock *ret_stmt_block = ir_create_basic_block(irb, child_scope, "RetStmt");
+    bool should_inline = ir_should_inline(irb->exec, child_scope);
+    IrInstruction *errdefers_is_comptime = ir_build_const_bool(irb, child_scope, block_node,
+            should_inline || !have_err_defers);
+    ir_mark_gen(ir_build_cond_br(irb, child_scope, block_node, is_canceled,
+                all_defers_block, normal_defers_block, errdefers_is_comptime));
+
+    ir_set_cursor_at_end_and_append_block(irb, all_defers_block);
+    ir_gen_defers_for_block(irb, child_scope, outer_block_scope, true);
+    ir_build_br(irb, child_scope, block_node, ret_stmt_block, errdefers_is_comptime);
+
+    ir_set_cursor_at_end_and_append_block(irb, normal_defers_block);
+    ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
+    ir_build_br(irb, child_scope, block_node, ret_stmt_block, errdefers_is_comptime);
+
+    ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
+    return ir_mark_gen(ir_build_return(irb, child_scope, result->source_node, result));
 }
 
 static IrInstruction *ir_gen_bin_op_id(IrBuilder *irb, Scope *scope, AstNode *node, IrBinOp op_id) {
@@ -8111,6 +8184,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
 
     irb->codegen = codegen;
     irb->exec = ir_executable;
+    irb->main_block_node = node;
 
     IrBasicBlock *entry_block = ir_create_basic_block(irb, scope, "Entry");
     ir_set_cursor_at_end_and_append_block(irb, entry_block);
@@ -24603,6 +24677,16 @@ static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstr
     return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
 }
 
+static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ira,
+        IrInstructionTestCancelRequested *instruction)
+{
+    if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) {
+        return ir_const_bool(ira, &instruction->base, false);
+    }
+    return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
+            instruction->use_return_begin_prev_value);
+}
+
 static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -24900,6 +24984,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
             return ir_analyze_instruction_coro_resume(ira, (IrInstructionCoroResume *)instruction);
         case IrInstructionIdAwaitSrc:
             return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction);
+        case IrInstructionIdTestCancelRequested:
+            return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction);
     }
     zig_unreachable();
 }
@@ -25134,6 +25220,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdHasDecl:
         case IrInstructionIdAllocaSrc:
         case IrInstructionIdAllocaGen:
+        case IrInstructionIdTestCancelRequested:
             return false;
 
         case IrInstructionIdAsm:
src/ir_print.cpp
@@ -1550,6 +1550,11 @@ static void ir_print_await_gen(IrPrint *irp, IrInstructionAwaitGen *instruction)
     fprintf(irp->f, ")");
 }
 
+static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancelRequested *instruction) {
+    const char *arg = instruction->use_return_begin_prev_value ? "UseReturnBeginPrevValue" : "AdditionalCheck";
+    fprintf(irp->f, "@testCancelRequested(%s)", arg);
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -2032,6 +2037,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdAwaitGen:
             ir_print_await_gen(irp, (IrInstructionAwaitGen *)instruction);
             break;
+        case IrInstructionIdTestCancelRequested:
+            ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/stage1/behavior/cancel.zig
@@ -48,8 +48,9 @@ var defer_b3: bool = false;
 var defer_b4: bool = false;
 
 test "cancel backwards" {
-    _ = async b1();
+    var b1_frame = async b1();
     resume b4_handle;
+    _ = async awaitAFrame(&b1_frame);
     expect(defer_b1);
     expect(defer_b2);
     expect(defer_b3);
@@ -63,7 +64,7 @@ async fn b1() void {
     b2();
 }
 
-var b4_handle: anyframe = undefined;
+var b4_handle: anyframe->void = undefined;
 
 async fn b2() void {
     const b3_handle = async b3();
@@ -93,6 +94,10 @@ async fn b4() void {
     suspend;
 }
 
+fn awaitAFrame(f: anyframe->void) void {
+    await f;
+}
+
 test "cancel on a non-pointer" {
     const S = struct {
         fn doTheTest() void {
test/stage1/behavior/coroutines.zig
@@ -134,29 +134,44 @@ test "@frameSize" {
 }
 
 test "coroutine suspend, resume" {
-    seq('a');
-    const p = async testAsyncSeq();
-    seq('c');
-    resume p;
-    seq('f');
-    // `cancel` is now a suspend point so it cannot be done here
-    seq('g');
+    const S = struct {
+        var frame: anyframe = undefined;
 
-    expect(std.mem.eql(u8, points, "abcdefg"));
-}
-async fn testAsyncSeq() void {
-    defer seq('e');
+        fn doTheTest() void {
+            _ = async amain();
+            seq('d');
+            resume frame;
+            seq('h');
 
-    seq('b');
-    suspend;
-    seq('d');
-}
-var points = [_]u8{0} ** "abcdefg".len;
-var index: usize = 0;
+            expect(std.mem.eql(u8, points, "abcdefgh"));
+        }
+
+        fn amain() void {
+            seq('a');
+            var f = async testAsyncSeq();
+            seq('c');
+            cancel f;
+            seq('g');
+        }
+
+        fn testAsyncSeq() void {
+            defer seq('f');
 
-fn seq(c: u8) void {
-    points[index] = c;
-    index += 1;
+            seq('b');
+            suspend {
+                frame = @frame();
+            }
+            seq('e');
+        }
+        var points = [_]u8{'x'} ** "abcdefgh".len;
+        var index: usize = 0;
+
+        fn seq(c: u8) void {
+            points[index] = c;
+            index += 1;
+        }
+    };
+    S.doTheTest();
 }
 
 test "coroutine suspend with block" {
@@ -267,12 +282,19 @@ test "async fn pointer in a struct field" {
     };
     var foo = Foo{ .bar = simpleAsyncFn2 };
     var bytes: [64]u8 = undefined;
-    const p = @asyncCall(&bytes, {}, foo.bar, &data);
-    comptime expect(@typeOf(p) == anyframe->void);
+    const f = @asyncCall(&bytes, {}, foo.bar, &data);
+    comptime expect(@typeOf(f) == anyframe->void);
     expect(data == 2);
-    resume p;
+    resume f;
+    expect(data == 2);
+    _ = async doTheAwait(f);
     expect(data == 4);
 }
+
+fn doTheAwait(f: anyframe->void) void {
+    await f;
+}
+
 async fn simpleAsyncFn2(y: *i32) void {
     defer y.* += 2;
     y.* += 1;
@@ -507,3 +529,42 @@ test "call async function which has struct return type" {
     };
     S.doTheTest();
 }
+
+test "errdefers in scope get run when canceling async fn call" {
+    const S = struct {
+        var frame: anyframe = undefined;
+        var x: u32 = 0;
+
+        fn doTheTest() void {
+            x = 9;
+            _ = async cancelIt();
+            resume frame;
+            expect(x == 6);
+
+            x = 9;
+            _ = async awaitIt();
+            resume frame;
+            expect(x == 11);
+        }
+
+        fn cancelIt() void {
+            var f = async func();
+            cancel f;
+        }
+
+        fn awaitIt() void {
+            var f = async func();
+            await f;
+        }
+
+        fn func() void {
+            defer x += 1;
+            errdefer x /= 2;
+            defer x += 1;
+            suspend {
+                frame = @frame();
+            }
+        }
+    };
+    S.doTheTest();
+}
BRANCH_TODO
@@ -2,8 +2,7 @@
  * compile error for error: expected anyframe->T, found 'i32'
  * await of a non async function
  * async call on a non async function
- * cancel
- * defer and errdefer
+ * a test where an async function destroys its own frame in a defer
  * implicit cast of normal function to async function should be allowed when it is inferred to be async
  * revive std.event.Loop
  * @typeInfo for @Frame(func)