Commit b9d1d45dfd

Andrew Kelley <andrew@ziglang.org>
2019-08-10 03:49:40
fix combining try with errdefer cancel
1 parent 2e7f53f
Changed files (5)
src/all_types.hpp
@@ -2366,6 +2366,7 @@ enum IrInstructionId {
     IrInstructionIdAwaitGen,
     IrInstructionIdCoroResume,
     IrInstructionIdTestCancelRequested,
+    IrInstructionIdSpill,
 };
 
 struct IrInstruction {
@@ -3643,6 +3644,18 @@ struct IrInstructionTestCancelRequested {
     IrInstruction base;
 };
 
+enum SpillId {
+    SpillIdInvalid,
+    SpillIdRetErrCode,
+};
+
+struct IrInstructionSpill {
+    IrInstruction base;
+
+    SpillId spill_id;
+    IrInstruction *operand;
+};
+
 enum ResultLocId {
     ResultLocIdInvalid,
     ResultLocIdNone,
src/codegen.cpp
@@ -5113,17 +5113,9 @@ static LLVMValueRef ir_render_test_err(CodeGen *g, IrExecutable *executable, IrI
     return LLVMBuildICmp(g->builder, LLVMIntNE, err_val, zero, "");
 }
 
-static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable,
-        IrInstructionUnwrapErrCode *instruction)
-{
-    if (instruction->base.value.special != ConstValSpecialRuntime)
-        return nullptr;
-
-    ZigType *ptr_type = instruction->err_union_ptr->value.type;
-    assert(ptr_type->id == ZigTypeIdPointer);
+static LLVMValueRef gen_unwrap_err_code(CodeGen *g, LLVMValueRef err_union_ptr, ZigType *ptr_type) {
     ZigType *err_union_type = ptr_type->data.pointer.child_type;
     ZigType *payload_type = err_union_type->data.error_union.payload_type;
-    LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr);
     if (!type_has_bits(payload_type)) {
         return err_union_ptr;
     } else {
@@ -5133,6 +5125,18 @@ static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executab
     }
 }
 
+static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable,
+        IrInstructionUnwrapErrCode *instruction)
+{
+    if (instruction->base.value.special != ConstValSpecialRuntime)
+        return nullptr;
+
+    ZigType *ptr_type = instruction->err_union_ptr->value.type;
+    assert(ptr_type->id == ZigTypeIdPointer);
+    LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr);
+    return gen_unwrap_err_code(g, err_union_ptr, ptr_type);
+}
+
 static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *executable,
         IrInstructionUnwrapErrPayload *instruction)
 {
@@ -5611,6 +5615,27 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex
     }
 }
 
+static LLVMValueRef ir_render_spill(CodeGen *g, IrExecutable *executable, IrInstructionSpill *instruction) {
+    if (!fn_is_async(g->cur_fn))
+        return ir_llvm_value(g, instruction->operand);
+
+    switch (instruction->spill_id) {
+        case SpillIdInvalid:
+            zig_unreachable();
+        case SpillIdRetErrCode: {
+            LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr, "");
+            ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
+            if (ret_type->id == ZigTypeIdErrorUnion) {
+                return gen_unwrap_err_code(g, ret_ptr, get_pointer_to_type(g, ret_type, true));
+            } else {
+                zig_unreachable();
+            }
+        }
+
+    }
+    zig_unreachable();
+}
+
 static void set_debug_location(CodeGen *g, IrInstruction *instruction) {
     AstNode *source_node = instruction->source_node;
     Scope *scope = instruction->scope;
@@ -5866,6 +5891,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction);
         case IrInstructionIdTestCancelRequested:
             return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction);
+        case IrInstructionIdSpill:
+            return ir_render_spill(g, executable, (IrInstructionSpill *)instruction);
     }
     zig_unreachable();
 }
src/ir.cpp
@@ -1066,6 +1066,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelReques
     return IrInstructionIdTestCancelRequested;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSpill *) {
+    return IrInstructionIdSpill;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -3332,6 +3336,18 @@ static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scop
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_spill(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *operand, SpillId spill_id)
+{
+    IrInstructionSpill *instruction = ir_build_instruction<IrInstructionSpill>(irb, scope, source_node);
+    instruction->operand = operand;
+    instruction->spill_id = spill_id;
+
+    ir_ref_instruction(operand, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -3591,6 +3607,7 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
                     ResultLocReturn *result_loc_ret = allocate<ResultLocReturn>(1);
                     result_loc_ret->base.id = ResultLocIdReturn;
                     ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
+                    err_val = ir_build_spill(irb, scope, node, err_val, SpillIdRetErrCode);
                     ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
 
                     if (irb->codegen->have_err_ret_tracing && !should_inline) {
@@ -24725,6 +24742,19 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir
     return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node);
 }
 
+static IrInstruction *ir_analyze_instruction_spill(IrAnalyze *ira, IrInstructionSpill *instruction) {
+    IrInstruction *operand = instruction->operand->child;
+    if (type_is_invalid(operand->value.type))
+        return ira->codegen->invalid_instruction;
+    if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) {
+        return operand;
+    }
+    IrInstruction *result = ir_build_spill(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
+            operand, instruction->spill_id);
+    result->value.type = operand->value.type;
+    return result;
+}
+
 static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -25024,6 +25054,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
             return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction);
         case IrInstructionIdTestCancelRequested:
             return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction);
+        case IrInstructionIdSpill:
+            return ir_analyze_instruction_spill(ira, (IrInstructionSpill *)instruction);
     }
     zig_unreachable();
 }
@@ -25259,6 +25291,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdAllocaSrc:
         case IrInstructionIdAllocaGen:
         case IrInstructionIdTestCancelRequested:
+        case IrInstructionIdSpill:
             return false;
 
         case IrInstructionIdAsm:
src/ir_print.cpp
@@ -1554,6 +1554,12 @@ static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancel
     fprintf(irp->f, "@testCancelRequested()");
 }
 
+static void ir_print_spill(IrPrint *irp, IrInstructionSpill *instruction) {
+    fprintf(irp->f, "@spill(");
+    ir_print_other_instruction(irp, instruction->operand);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -2039,6 +2045,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdTestCancelRequested:
             ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction);
             break;
+        case IrInstructionIdSpill:
+            ir_print_spill(irp, (IrInstructionSpill *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/stage1/behavior/coroutines.zig
@@ -613,3 +613,32 @@ test "cancel inside an errdefer" {
     };
     S.doTheTest();
 }
+
+test "combining try with errdefer cancel" {
+    const S = struct {
+        var frame: anyframe = undefined;
+        var ok = false;
+
+        fn doTheTest() void {
+            _ = async amain();
+            resume frame;
+            expect(ok);
+        }
+
+        fn amain() !void {
+            var f = async func("https://example.com/");
+            errdefer cancel f;
+
+            _ = try await f;
+        }
+
+        fn func(url: []const u8) ![]u8 {
+            errdefer ok = true;
+            frame = @frame();
+            suspend;
+            return error.Bad;
+        }
+
+    };
+    S.doTheTest();
+}