Commit af8c6ccb4b

Andrew Kelley <andrew@ziglang.org>
2019-08-11 20:26:34
fix canceling async functions which have error return tracing
1 parent 3f5c6d7
Changed files (2)
src/codegen.cpp
@@ -2071,6 +2071,7 @@ static LLVMValueRef get_merge_err_ret_traces_fn_val(CodeGen *g) {
     LLVMPositionBuilderAtEnd(g->builder, entry_block);
     ZigLLVMClearCurrentDebugLocation(g->builder);
 
+    // if (dest_stack_trace == null) return;
     // var frame_index: usize = undefined;
     // var frames_left: usize = undefined;
     // if (src_stack_trace.index < src_stack_trace.instruction_addresses.len) {
@@ -2088,6 +2089,7 @@ static LLVMValueRef get_merge_err_ret_traces_fn_val(CodeGen *g) {
     //     frame_index = (frame_index + 1) % src_stack_trace.instruction_addresses.len;
     // }
     LLVMBasicBlockRef return_block = LLVMAppendBasicBlock(fn_val, "Return");
+    LLVMBasicBlockRef dest_non_null_block = LLVMAppendBasicBlock(fn_val, "DestNonNull");
 
     LLVMValueRef frame_index_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frame_index");
     LLVMValueRef frames_left_ptr = LLVMBuildAlloca(g->builder, g->builtin_types.entry_usize->llvm_type, "frames_left");
@@ -2095,6 +2097,11 @@ static LLVMValueRef get_merge_err_ret_traces_fn_val(CodeGen *g) {
     LLVMValueRef dest_stack_trace_ptr = LLVMGetParam(fn_val, 0);
     LLVMValueRef src_stack_trace_ptr = LLVMGetParam(fn_val, 1);
 
+    LLVMValueRef null_dest_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, dest_stack_trace_ptr,
+            LLVMConstNull(LLVMTypeOf(dest_stack_trace_ptr)), "");
+    LLVMBuildCondBr(g->builder, null_dest_bit, return_block, dest_non_null_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, dest_non_null_block);
     size_t src_index_field_index = g->stack_trace_type->data.structure.fields[0].gen_index;
     size_t src_addresses_field_index = g->stack_trace_type->data.structure.fields[1].gen_index;
     LLVMValueRef src_index_field_ptr = LLVMBuildStructGEP(g->builder, src_stack_trace_ptr,
@@ -5480,10 +5487,20 @@ static LLVMValueRef ir_render_cancel(CodeGen *g, IrExecutable *executable, IrIns
     LLVMValueRef zero = LLVMConstNull(usize_type_ref);
     LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref);
     LLVMValueRef one = LLVMConstInt(usize_type_ref, 1, false);
+    src_assert(instruction->frame->value.type->id == ZigTypeIdAnyFrame, instruction->base.source_node);
+    ZigType *result_type = instruction->frame->value.type->data.any_frame.result_type;
 
     LLVMValueRef target_frame_ptr = ir_llvm_value(g, instruction->frame);
     LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "CancelResume");
 
+    // supply null for the error return trace pointer
+    if (codegen_fn_has_err_ret_tracing_arg(g, result_type)) {
+        LLVMValueRef err_ret_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr,
+                frame_index_trace_arg(g, result_type), "");
+        LLVMBuildStore(g->builder, LLVMConstNull(LLVMGetElementType(LLVMTypeOf(err_ret_trace_ptr_ptr))),
+                err_ret_trace_ptr_ptr);
+    }
+
     LLVMValueRef awaiter_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, "");
     LLVMValueRef awaiter_ored_val = LLVMBuildOr(g->builder, awaiter_val, one, "");
     LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, "");
src/ir.cpp
@@ -24656,26 +24656,51 @@ static IrInstruction *ir_analyze_instruction_suspend_finish(IrAnalyze *ira,
     return ir_build_suspend_finish(&ira->new_irb, instruction->base.scope, instruction->base.source_node, begin);
 }
 
-static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
-    IrInstruction *frame_ptr = instruction->frame->child;
+static IrInstruction *analyze_frame_ptr_to_anyframe_T(IrAnalyze *ira, IrInstruction *source_instr,
+        IrInstruction *frame_ptr)
+{
     if (type_is_invalid(frame_ptr->value.type))
         return ira->codegen->invalid_instruction;
 
+    ZigType *result_type;
     IrInstruction *frame;
     if (frame_ptr->value.type->id == ZigTypeIdPointer &&
         frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
         frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdCoroFrame)
     {
+        result_type = frame_ptr->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type;
         frame = frame_ptr;
     } else {
-        frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr);
+        frame = ir_get_deref(ira, source_instr, frame_ptr, nullptr);
+        if (frame->value.type->id == ZigTypeIdPointer &&
+            frame->value.type->data.pointer.ptr_len == PtrLenSingle &&
+            frame->value.type->data.pointer.child_type->id == ZigTypeIdCoroFrame)
+        {
+            result_type = frame->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type;
+        } else if (frame->value.type->id != ZigTypeIdAnyFrame ||
+            frame->value.type->data.any_frame.result_type == nullptr)
+        {
+            ir_add_error(ira, source_instr,
+                buf_sprintf("expected anyframe->T, found '%s'", buf_ptr(&frame->value.type->name)));
+            return ira->codegen->invalid_instruction;
+        } else {
+            result_type = frame->value.type->data.any_frame.result_type;
+        }
     }
 
-    ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
+    ZigType *any_frame_type = get_any_frame_type(ira->codegen, result_type);
     IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
     if (type_is_invalid(casted_frame->value.type))
         return ira->codegen->invalid_instruction;
 
+    return casted_frame;
+}
+
+static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructionCancel *instruction) {
+    IrInstruction *frame = analyze_frame_ptr_to_anyframe_T(ira, &instruction->base, instruction->frame->child);
+    if (type_is_invalid(frame->value.type))
+        return ira->codegen->invalid_instruction;
+
     ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec);
     ir_assert(fn_entry != nullptr, &instruction->base);
 
@@ -24683,38 +24708,15 @@ static IrInstruction *ir_analyze_instruction_cancel(IrAnalyze *ira, IrInstructio
         fn_entry->inferred_async_node = instruction->base.source_node;
     }
 
-    return ir_build_cancel(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
+    return ir_build_cancel(&ira->new_irb, instruction->base.scope, instruction->base.source_node, frame);
 }
 
 static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwaitSrc *instruction) {
-    IrInstruction *frame_ptr = instruction->frame->child;
-    if (type_is_invalid(frame_ptr->value.type))
+    IrInstruction *frame = analyze_frame_ptr_to_anyframe_T(ira, &instruction->base, instruction->frame->child);
+    if (type_is_invalid(frame->value.type))
         return ira->codegen->invalid_instruction;
 
-    ZigType *result_type;
-    IrInstruction *frame;
-    if (frame_ptr->value.type->id == ZigTypeIdPointer &&
-        frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
-        frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdCoroFrame)
-    {
-        result_type = frame_ptr->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type;
-        frame = frame_ptr;
-    } else {
-        frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr);
-        if (frame->value.type->id != ZigTypeIdAnyFrame ||
-            frame->value.type->data.any_frame.result_type == nullptr)
-        {
-            ir_add_error(ira, &instruction->base,
-                buf_sprintf("expected anyframe->T, found '%s'", buf_ptr(&frame->value.type->name)));
-            return ira->codegen->invalid_instruction;
-        }
-        result_type = frame->value.type->data.any_frame.result_type;
-    }
-
-    ZigType *any_frame_type = get_any_frame_type(ira->codegen, result_type);
-    IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
-    if (type_is_invalid(casted_frame->value.type))
-        return ira->codegen->invalid_instruction;
+    ZigType *result_type = frame->value.type->data.any_frame.result_type;
 
     ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec);
     ir_assert(fn_entry != nullptr, &instruction->base);