Commit 7e1fcb55b3

Andrew Kelley <andrew@ziglang.org>
2019-08-07 06:52:56
implement cancel
all behavior tests passing in this branch
1 parent 1afbb53
src/all_types.hpp
@@ -1556,6 +1556,7 @@ enum PanicMsgId {
     PanicMsgIdBadAwait,
     PanicMsgIdBadReturn,
     PanicMsgIdResumedAnAwaitingFn,
+    PanicMsgIdResumedACancelingFn,
     PanicMsgIdFrameTooSmall,
     PanicMsgIdResumedFnPendingAwait,
 
@@ -3432,7 +3433,7 @@ struct IrInstructionErrorUnion {
 struct IrInstructionCancel {
     IrInstruction base;
 
-    IrInstruction *target;
+    IrInstruction *frame;
 };
 
 struct IrInstructionAtomicRmw {
src/analyze.cpp
@@ -3811,6 +3811,9 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
     } else if (fn->inferred_async_node->type == NodeTypeAwaitExpr) {
         add_error_note(g, msg, fn->inferred_async_node,
             buf_sprintf("await is a suspend point"));
+    } else if (fn->inferred_async_node->type == NodeTypeCancel) {
+        add_error_note(g, msg, fn->inferred_async_node,
+            buf_sprintf("cancel is a suspend point"));
     } else {
         zig_unreachable();
     }
src/codegen.cpp
@@ -911,11 +911,13 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
         case PanicMsgIdBadResume:
             return buf_create_from_str("resumed an async function which already returned");
         case PanicMsgIdBadAwait:
-            return buf_create_from_str("async function awaited twice");
+            return buf_create_from_str("async function awaited/canceled twice");
         case PanicMsgIdBadReturn:
             return buf_create_from_str("async function returned twice");
         case PanicMsgIdResumedAnAwaitingFn:
             return buf_create_from_str("awaiting function resumed");
+        case PanicMsgIdResumedACancelingFn:
+            return buf_create_from_str("canceling function resumed");
         case PanicMsgIdFrameTooSmall:
             return buf_create_from_str("frame too small");
         case PanicMsgIdResumedFnPendingAwait:
@@ -2189,12 +2191,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume
     if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume");
     LLVMValueRef ok_bit;
     if (resume_id == ResumeIdAwaitEarlyReturn) {
-        LLVMValueRef last_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
-                LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false), "");
+        LLVMValueRef last_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
+                LLVMConstInt(usize_type_ref, ResumeIdAwaitEarlyReturn, false));
         ok_bit = LLVMBuildICmp(g->builder, LLVMIntULT, LLVMGetParam(g->cur_fn_val, 1), last_value, "");
     } else {
-        LLVMValueRef expected_value = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
-                LLVMConstInt(usize_type_ref, resume_id, false), "");
+        LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
+                LLVMConstInt(usize_type_ref, resume_id, false));
         ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, LLVMGetParam(g->cur_fn_val, 1), expected_value, "");
     }
     LLVMBuildCondBr(g->builder, ok_bit, end_bb, bad_resume_block);
@@ -2210,11 +2212,13 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar
 {
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
     if (fn_val == nullptr) {
-        if (g->anyframe_fn_type == nullptr) {
-            (void)get_llvm_type(g, get_any_frame_type(g, nullptr));
-        }
         LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_fn_ptr_index, "");
-        fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
+        LLVMValueRef fn_val_typed = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
+        LLVMValueRef as_int = LLVMBuildPtrToInt(g->builder, fn_val_typed, usize_type_ref, "");
+        LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
+        LLVMValueRef mask_val = LLVMConstNot(one);
+        LLVMValueRef as_int_masked = LLVMBuildAnd(g->builder, as_int, mask_val, "");
+        fn_val = LLVMBuildIntToPtr(g->builder, as_int_masked, LLVMTypeOf(fn_val_typed), "");
     }
     if (arg_val == nullptr) {
         arg_val = LLVMBuildSub(g->builder, LLVMConstAllOnes(usize_type_ref),
@@ -2226,6 +2230,17 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar
     return ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
 }
 
+static LLVMBasicBlockRef gen_suspend_begin(CodeGen *g, const char *name_hint) {
+    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
+    LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, name_hint);
+    size_t new_block_index = g->cur_resume_block_count;
+    g->cur_resume_block_count += 1;
+    LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
+    LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
+    LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
+    return resume_bb;
+}
+
 static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
         IrInstructionReturnBegin *instruction)
 {
@@ -2245,12 +2260,7 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
     }
 
     // Prepare to be suspended. We might end up not having to suspend though.
-    LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "ReturnResume");
-    size_t new_block_index = g->cur_resume_block_count;
-    g->cur_resume_block_count += 1;
-    LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
-    LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
-    LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
+    LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "ReturnResume");
 
     LLVMValueRef zero = LLVMConstNull(usize_type_ref);
     LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
@@ -2335,7 +2345,10 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns
 
         // We need to resume the caller by tail calling them.
         ZigType *any_frame_type = get_any_frame_type(g, ret_type);
-        LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, g->cur_async_prev_val,
+        LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
+        LLVMValueRef mask_val = LLVMConstNot(one);
+        LLVMValueRef masked_prev_val = LLVMBuildAnd(g->builder, g->cur_async_prev_val, mask_val, "");
+        LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, masked_prev_val,
                 get_llvm_type(g, any_frame_type), "");
         LLVMValueRef call_inst = gen_resume(g, nullptr, their_frame_ptr, ResumeIdReturn, nullptr);
         ZigLLVMSetTailCall(call_inst);
@@ -3945,13 +3958,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
     } else if (callee_is_async) {
         ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true);
 
-        LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(g->cur_fn_val, "CallResume");
-        size_t new_block_index = g->cur_resume_block_count;
-        g->cur_resume_block_count += 1;
-        LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
-        LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, call_bb);
-
-        LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
+        LLVMBasicBlockRef call_bb = gen_suspend_begin(g, "CallResume");
 
         LLVMValueRef call_inst = gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
         ZigLLVMSetTailCall(call_inst);
@@ -4672,10 +4679,6 @@ static LLVMValueRef ir_render_error_return_trace(CodeGen *g, IrExecutable *execu
     return cur_err_ret_trace_val;
 }
 
-static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) {
-    zig_panic("TODO cancel");
-}
-
 static LLVMAtomicOrdering to_LLVMAtomicOrdering(AtomicOrder atomic_order) {
     switch (atomic_order) {
         case AtomicOrderUnordered: return LLVMAtomicOrderingUnordered;
@@ -5416,13 +5419,7 @@ static LLVMValueRef ir_render_assert_non_null(CodeGen *g, IrExecutable *executab
 static LLVMValueRef ir_render_suspend_begin(CodeGen *g, IrExecutable *executable,
         IrInstructionSuspendBegin *instruction)
 {
-    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
-    instruction->resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "SuspendResume");
-    size_t new_block_index = g->cur_resume_block_count;
-    g->cur_resume_block_count += 1;
-    LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
-    LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, instruction->resume_bb);
-    LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
+    instruction->resume_bb = gen_suspend_begin(g, "SuspendResume");
     return nullptr;
 }
 
@@ -5436,6 +5433,43 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl
     return nullptr;
 }
 
+static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrInstructionCancel *instruction) {
+    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
+    LLVMValueRef zero = LLVMConstNull(usize_type_ref);
+    LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
+    LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
+
+    LLVMValueRef target_frame_ptr = ir_llvm_value(g, instruction->frame);
+    LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "CancelResume");
+
+    LLVMValueRef awaiter_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, "");
+    LLVMValueRef awaiter_ored_val = LLVMBuildOr(g->builder, awaiter_val, one, "");
+    LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, "");
+
+    LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, awaiter_ored_val,
+            LLVMAtomicOrderingRelease, g->is_single_threaded);
+
+    LLVMBasicBlockRef complete_suspend_block = LLVMAppendBasicBlock(g->cur_fn_val, "CancelSuspend");
+    LLVMBasicBlockRef early_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "EarlyReturn");
+
+    LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, prev_val, resume_bb, 2);
+    LLVMAddCase(switch_instr, zero, complete_suspend_block);
+    LLVMAddCase(switch_instr, all_ones, early_return_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, complete_suspend_block);
+    LLVMBuildRetVoid(g->builder);
+
+    LLVMPositionBuilderAtEnd(g->builder, early_return_block);
+    LLVMValueRef call_inst = gen_resume(g, nullptr, target_frame_ptr, ResumeIdAwaitEarlyReturn, awaiter_ored_val);
+    ZigLLVMSetTailCall(call_inst);
+    LLVMBuildRetVoid(g->builder);
+
+    LLVMPositionBuilderAtEnd(g->builder, resume_bb);
+    gen_assert_resume_id(g, &instruction->base, ResumeIdReturn, PanicMsgIdResumedACancelingFn, nullptr);
+
+    return nullptr;
+}
+
 static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInstructionAwaitGen *instruction) {
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
     LLVMValueRef zero = LLVMConstNull(usize_type_ref);
@@ -5444,12 +5478,7 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true);
 
     // Prepare to be suspended
-    LLVMBasicBlockRef resume_bb = LLVMAppendBasicBlock(g->cur_fn_val, "AwaitResume");
-    size_t new_block_index = g->cur_resume_block_count;
-    g->cur_resume_block_count += 1;
-    LLVMValueRef new_block_index_val = LLVMConstInt(usize_type_ref, new_block_index, false);
-    LLVMAddCase(g->cur_async_switch_instr, new_block_index_val, resume_bb);
-    LLVMBuildStore(g->builder, new_block_index_val, g->cur_async_resume_index_ptr);
+    LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "AwaitResume");
 
     // At this point resuming the function will do the correct thing.
     // This code is as if it is running inside the suspend block.
src/ir.cpp
@@ -3271,6 +3271,16 @@ static IrInstruction *ir_build_suspend_finish(IrBuilder *irb, Scope *scope, AstN
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_cancel(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame) {
+    IrInstructionCancel *instruction = ir_build_instruction<IrInstructionCancel>(irb, scope, source_node);
+    instruction->base.value.type = irb->codegen->builtin_types.entry_void;
+    instruction->frame = frame;
+
+    ir_ref_instruction(frame, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_await_src(IrBuilder *irb, Scope *scope, AstNode *source_node,
         IrInstruction *frame, ResultLoc *result_loc)
 {
@@ -7820,11 +7830,26 @@ static IrInstruction *ir_gen_fn_proto(IrBuilder *irb, Scope *parent_scope, AstNo
 static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node) {
     assert(node->type == NodeTypeCancel);
 
-    IrInstruction *target_inst = ir_gen_node(irb, node->data.cancel_expr.expr, scope);
-    if (target_inst == irb->codegen->invalid_instruction)
+    ZigFn *fn_entry = exec_fn_entry(irb->exec);
+    if (!fn_entry) {
+        add_node_error(irb->codegen, node, buf_sprintf("cancel outside function definition"));
+        return irb->codegen->invalid_instruction;
+    }
+    ScopeSuspend *existing_suspend_scope = get_scope_suspend(scope);
+    if (existing_suspend_scope) {
+        if (!existing_suspend_scope->reported_err) {
+            ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot cancel inside suspend block"));
+            add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("suspend block here"));
+            existing_suspend_scope->reported_err = true;
+        }
+        return irb->codegen->invalid_instruction;
+    }
+
+    IrInstruction *operand = ir_gen_node(irb, node->data.cancel_expr.expr, scope);
+    if (operand == irb->codegen->invalid_instruction)
         return irb->codegen->invalid_instruction;
 
-    zig_panic("TODO ir_gen_cancel");
+    return ir_build_cancel(irb, scope, node, operand);
 }
 
 static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) {
@@ -23781,10 +23806,6 @@ static IrInstruction *ir_analyze_instruction_tag_type(IrAnalyze *ira, IrInstruct
     }
 }
 
-static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
-    zig_panic("TODO analyze cancel");
-}
-
 static ZigType *ir_resolve_atomic_operand_type(IrAnalyze *ira, IrInstruction *op) {
     ZigType *operand_type = ir_resolve_type(ira, op);
     if (type_is_invalid(operand_type))
@@ -24474,6 +24495,26 @@ static IrInstruction *ir_analyze_instruction_suspend_finish(IrAnalyze *ira,
     return ir_build_suspend_finish(&ira->new_irb, instruction->base.scope, instruction->base.source_node, begin);
 }
 
+static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
+    IrInstruction *frame = instruction->frame->child;
+    if (type_is_invalid(frame->value.type))
+        return ira->codegen->invalid_instruction;
+
+    ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
+    IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
+    if (type_is_invalid(casted_frame->value.type))
+        return ira->codegen->invalid_instruction;
+
+    ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec);
+    ir_assert(fn_entry != nullptr, &instruction->base);
+
+    if (fn_entry->inferred_async_node == nullptr) {
+        fn_entry->inferred_async_node = instruction->base.source_node;
+    }
+
+    return ir_build_cancel(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
+}
+
 static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwaitSrc *instruction) {
     IrInstruction *frame_ptr = instruction->frame->child;
     if (type_is_invalid(frame_ptr->value.type))
src/ir_print.cpp
@@ -1396,7 +1396,7 @@ static void ir_print_error_union(IrPrint *irp, IrInstructionErrorUnion *instruct
 
 static void ir_print_cancel(IrPrint *irp, IrInstructionCancel *instruction) {
     fprintf(irp->f, "cancel ");
-    ir_print_other_instruction(irp, instruction->target);
+    ir_print_other_instruction(irp, instruction->frame);
 }
 
 static void ir_print_atomic_rmw(IrPrint *irp, IrInstructionAtomicRmw *instruction) {
test/stage1/behavior/cancel.zig
@@ -1,86 +1,94 @@
 const std = @import("std");
+const expect = std.testing.expect;
 
-//var defer_f1: bool = false;
-//var defer_f2: bool = false;
-//var defer_f3: bool = false;
-//
-//test "cancel forwards" {
-//    const p = async<std.heap.direct_allocator> f1() catch unreachable;
-//    cancel p;
-//    std.testing.expect(defer_f1);
-//    std.testing.expect(defer_f2);
-//    std.testing.expect(defer_f3);
-//}
-//
-//async fn f1() void {
-//    defer {
-//        defer_f1 = true;
-//    }
-//    await (async f2() catch unreachable);
-//}
-//
-//async fn f2() void {
-//    defer {
-//        defer_f2 = true;
-//    }
-//    await (async f3() catch unreachable);
-//}
-//
-//async fn f3() void {
-//    defer {
-//        defer_f3 = true;
-//    }
-//    suspend;
-//}
-//
-//var defer_b1: bool = false;
-//var defer_b2: bool = false;
-//var defer_b3: bool = false;
-//var defer_b4: bool = false;
-//
-//test "cancel backwards" {
-//    const p = async<std.heap.direct_allocator> b1() catch unreachable;
-//    cancel p;
-//    std.testing.expect(defer_b1);
-//    std.testing.expect(defer_b2);
-//    std.testing.expect(defer_b3);
-//    std.testing.expect(defer_b4);
-//}
-//
-//async fn b1() void {
-//    defer {
-//        defer_b1 = true;
-//    }
-//    await (async b2() catch unreachable);
-//}
-//
-//var b4_handle: promise = undefined;
-//
-//async fn b2() void {
-//    const b3_handle = async b3() catch unreachable;
-//    resume b4_handle;
-//    cancel b4_handle;
-//    defer {
-//        defer_b2 = true;
-//    }
-//    const value = await b3_handle;
-//    @panic("unreachable");
-//}
-//
-//async fn b3() i32 {
-//    defer {
-//        defer_b3 = true;
-//    }
-//    await (async b4() catch unreachable);
-//    return 1234;
-//}
-//
-//async fn b4() void {
-//    defer {
-//        defer_b4 = true;
-//    }
-//    suspend {
-//        b4_handle = @handle();
-//    }
-//    suspend;
-//}
+var defer_f1: bool = false;
+var defer_f2: bool = false;
+var defer_f3: bool = false;
+var f3_frame: anyframe = undefined;
+
+test "cancel forwards" {
+    _ = async atest1();
+    resume f3_frame;
+}
+
+fn atest1() void {
+    const p = async f1();
+    cancel &p;
+    expect(defer_f1);
+    expect(defer_f2);
+    expect(defer_f3);
+}
+
+async fn f1() void {
+    defer {
+        defer_f1 = true;
+    }
+    var f2_frame = async f2();
+    await f2_frame;
+}
+
+async fn f2() void {
+    defer {
+        defer_f2 = true;
+    }
+    f3();
+}
+
+async fn f3() void {
+    f3_frame = @frame();
+    defer {
+        defer_f3 = true;
+    }
+    suspend;
+}
+
+var defer_b1: bool = false;
+var defer_b2: bool = false;
+var defer_b3: bool = false;
+var defer_b4: bool = false;
+
+test "cancel backwards" {
+    _ = async b1();
+    resume b4_handle;
+    expect(defer_b1);
+    expect(defer_b2);
+    expect(defer_b3);
+    expect(defer_b4);
+}
+
+async fn b1() void {
+    defer {
+        defer_b1 = true;
+    }
+    b2();
+}
+
+var b4_handle: anyframe = undefined;
+
+async fn b2() void {
+    const b3_handle = async b3();
+    resume b4_handle;
+    defer {
+        defer_b2 = true;
+    }
+    const value = await b3_handle;
+    expect(value == 1234);
+}
+
+async fn b3() i32 {
+    defer {
+        defer_b3 = true;
+    }
+    b4();
+    return 1234;
+}
+
+async fn b4() void {
+    defer {
+        defer_b4 = true;
+    }
+    suspend {
+        b4_handle = @frame();
+    }
+    suspend;
+}
BRANCH_TODO
@@ -1,4 +1,4 @@
- * go over the commented out tests in cancel.zig
+ * clean up the bitcasting of awaiter fn ptr
  * compile error for error: expected anyframe->T, found 'anyframe'
  * compile error for error: expected anyframe->T, found 'i32'
  * await of a non async function