Commit 18af2f9a27

Andrew Kelley <superjoe30@gmail.com>
2018-03-24 23:21:51
fix async fns with inferred error sets
closes #856
1 parent b1c07c0
src/all_types.hpp
@@ -1251,7 +1251,10 @@ struct FnTableEntry {
     ScopeBlock *def_scope; // parent is child_scope
     Buf symbol_name;
     TypeTableEntry *type_entry; // function type
-    TypeTableEntry *implicit_return_type;
+    // in the case of normal functions this is the implicit return type
+    // in the case of async functions this is the implicit return type according to the
+    // zig source code, not according to zig ir
+    TypeTableEntry *src_implicit_return_type;
     bool is_test;
     FnInline fn_inline;
     FnAnalState anal_state;
@@ -2035,6 +2038,7 @@ enum IrInstructionId {
     IrInstructionIdPromiseResultType,
     IrInstructionIdAwaitBookkeeping,
     IrInstructionIdSaveErrRetAddr,
+    IrInstructionIdAddImplicitReturnType,
 };
 
 struct IrInstruction {
@@ -2993,6 +2997,12 @@ struct IrInstructionSaveErrRetAddr {
     IrInstruction base;
 };
 
+struct IrInstructionAddImplicitReturnType {
+    IrInstruction base;
+
+    IrInstruction *value;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/analyze.cpp
@@ -3865,7 +3865,7 @@ void analyze_fn_ir(CodeGen *g, FnTableEntry *fn_table_entry, AstNode *return_typ
 
     TypeTableEntry *block_return_type = ir_analyze(g, &fn_table_entry->ir_executable,
             &fn_table_entry->analyzed_executable, fn_type_id->return_type, return_type_node);
-    fn_table_entry->implicit_return_type = block_return_type;
+    fn_table_entry->src_implicit_return_type = block_return_type;
 
     if (type_is_invalid(block_return_type) || fn_table_entry->analyzed_executable.invalid) {
         assert(g->errors.length > 0);
@@ -3877,10 +3877,10 @@ void analyze_fn_ir(CodeGen *g, FnTableEntry *fn_table_entry, AstNode *return_typ
         TypeTableEntry *return_err_set_type = fn_type_id->return_type->data.error_union.err_set_type;
         if (return_err_set_type->data.error_set.infer_fn != nullptr) {
             TypeTableEntry *inferred_err_set_type;
-            if (fn_table_entry->implicit_return_type->id == TypeTableEntryIdErrorSet) {
-                inferred_err_set_type = fn_table_entry->implicit_return_type;
-            } else if (fn_table_entry->implicit_return_type->id == TypeTableEntryIdErrorUnion) {
-                inferred_err_set_type = fn_table_entry->implicit_return_type->data.error_union.err_set_type;
+            if (fn_table_entry->src_implicit_return_type->id == TypeTableEntryIdErrorSet) {
+                inferred_err_set_type = fn_table_entry->src_implicit_return_type;
+            } else if (fn_table_entry->src_implicit_return_type->id == TypeTableEntryIdErrorUnion) {
+                inferred_err_set_type = fn_table_entry->src_implicit_return_type->data.error_union.err_set_type;
             } else {
                 add_node_error(g, return_type_node,
                         buf_sprintf("function with inferred error set must return at least one possible error"));
src/ast_render.cpp
@@ -658,6 +658,15 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
                 if (node->data.fn_call_expr.is_builtin) {
                     fprintf(ar->f, "@");
                 }
+                if (node->data.fn_call_expr.is_async) {
+                    fprintf(ar->f, "async");
+                    if (node->data.fn_call_expr.async_allocator != nullptr) {
+                        fprintf(ar->f, "<");
+                        render_node_extra(ar, node->data.fn_call_expr.async_allocator, true);
+                        fprintf(ar->f, ">");
+                    }
+                    fprintf(ar->f, " ");
+                }
                 AstNode *fn_ref_node = node->data.fn_call_expr.fn_ref_expr;
                 bool grouped = (fn_ref_node->type != NodeTypePrefixOpExpr && fn_ref_node->type != NodeTypeAddrOfExpr);
                 render_node_extra(ar, fn_ref_node, grouped);
@@ -1023,7 +1032,7 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
         case NodeTypeUnwrapErrorExpr:
             {
                 render_node_ungrouped(ar, node->data.unwrap_err_expr.op1);
-                fprintf(ar->f, " %%%% ");
+                fprintf(ar->f, " catch ");
                 if (node->data.unwrap_err_expr.symbol) {
                     Buf *var_name = node->data.unwrap_err_expr.symbol->data.symbol_expr.symbol;
                     fprintf(ar->f, "|%s| ", buf_ptr(var_name));
src/codegen.cpp
@@ -4245,6 +4245,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdErrorUnion:
         case IrInstructionIdPromiseResultType:
         case IrInstructionIdAwaitBookkeeping:
+        case IrInstructionIdAddImplicitReturnType:
             zig_unreachable();
 
         case IrInstructionIdReturn:
src/ir.cpp
@@ -34,7 +34,7 @@ struct IrAnalyze {
     size_t old_bb_index;
     size_t instruction_index;
     TypeTableEntry *explicit_return_type;
-    ZigList<IrInstruction *> implicit_return_type_list;
+    ZigList<IrInstruction *> src_implicit_return_type_list;
     IrBasicBlock *const_predecessor_bb;
 };
 
@@ -717,6 +717,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionSaveErrRetAddr *
     return IrInstructionIdSaveErrRetAddr;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionAddImplicitReturnType *) {
+    return IrInstructionIdAddImplicitReturnType;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -2687,6 +2691,17 @@ static IrInstruction *ir_build_save_err_ret_addr(IrBuilder *irb, Scope *scope, A
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_add_implicit_return_type(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *value)
+{
+    IrInstructionAddImplicitReturnType *instruction = ir_build_instruction<IrInstructionAddImplicitReturnType>(irb, scope, source_node);
+    instruction->value = value;
+
+    ir_ref_instruction(value, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -2767,6 +2782,8 @@ static bool exec_is_async(IrExecutable *exec) {
 static IrInstruction *ir_gen_async_return(IrBuilder *irb, Scope *scope, AstNode *node, IrInstruction *return_value,
     bool is_generated_code)
 {
+    ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value));
+
     bool is_async = exec_is_async(irb->exec);
     if (!is_async) {
         IrInstruction *return_inst = ir_build_return(irb, scope, node, return_value);
@@ -6399,6 +6416,8 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         ir_build_cond_br(irb, scope, node, alloc_result_is_ok, alloc_ok_block, alloc_err_block, const_bool_false);
 
         ir_set_cursor_at_end_and_append_block(irb, alloc_err_block);
+        // we can return undefined here, because the caller passes a pointer to the error struct field
+        // in the error union result, and we populate it in case of allocation failure.
         IrInstruction *undef = ir_build_const_undefined(irb, scope, node);
         ir_build_return(irb, scope, node, undef);
 
@@ -10108,13 +10127,26 @@ static Buf *ir_resolve_str(IrAnalyze *ira, IrInstruction *value) {
     return result;
 }
 
+static TypeTableEntry *ir_analyze_instruction_add_implicit_return_type(IrAnalyze *ira,
+        IrInstructionAddImplicitReturnType *instruction)
+{
+    IrInstruction *value = instruction->value->other;
+    if (type_is_invalid(value->value.type))
+        return ir_unreach_error(ira);
+
+    ira->src_implicit_return_type_list.append(value);
+
+    ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+    out_val->type = ira->codegen->builtin_types.entry_void;
+    return out_val->type;
+}
+
 static TypeTableEntry *ir_analyze_instruction_return(IrAnalyze *ira,
     IrInstructionReturn *return_instruction)
 {
     IrInstruction *value = return_instruction->value->other;
     if (type_is_invalid(value->value.type))
         return ir_unreach_error(ira);
-    ira->implicit_return_type_list.append(value);
 
     IrInstruction *casted_value = ir_implicit_cast(ira, value, ira->explicit_return_type);
     if (casted_value == ira->codegen->invalid_instruction)
@@ -18049,6 +18081,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_await_bookkeeping(ira, (IrInstructionAwaitBookkeeping *)instruction);
         case IrInstructionIdSaveErrRetAddr:
             return ir_analyze_instruction_save_err_ret_addr(ira, (IrInstructionSaveErrRetAddr *)instruction);
+        case IrInstructionIdAddImplicitReturnType:
+            return ir_analyze_instruction_add_implicit_return_type(ira, (IrInstructionAddImplicitReturnType *)instruction);
     }
     zig_unreachable();
 }
@@ -18122,11 +18156,11 @@ TypeTableEntry *ir_analyze(CodeGen *codegen, IrExecutable *old_exec, IrExecutabl
 
     if (new_exec->invalid) {
         return ira->codegen->builtin_types.entry_invalid;
-    } else if (ira->implicit_return_type_list.length == 0) {
+    } else if (ira->src_implicit_return_type_list.length == 0) {
         return codegen->builtin_types.entry_unreachable;
     } else {
-        return ir_resolve_peer_types(ira, expected_type_source_node, ira->implicit_return_type_list.items,
-                ira->implicit_return_type_list.length);
+        return ir_resolve_peer_types(ira, expected_type_source_node, ira->src_implicit_return_type_list.items,
+                ira->src_implicit_return_type_list.length);
     }
 }
 
@@ -18175,6 +18209,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdCoroAllocHelper:
         case IrInstructionIdAwaitBookkeeping:
         case IrInstructionIdSaveErrRetAddr:
+        case IrInstructionIdAddImplicitReturnType:
             return true;
 
         case IrInstructionIdPhi:
src/ir_print.cpp
@@ -201,9 +201,9 @@ static void ir_print_call(IrPrint *irp, IrInstructionCall *call_instruction) {
     if (call_instruction->is_async) {
         fprintf(irp->f, "async");
         if (call_instruction->async_allocator != nullptr) {
-            fprintf(irp->f, "(");
+            fprintf(irp->f, "<");
             ir_print_other_instruction(irp, call_instruction->async_allocator);
-            fprintf(irp->f, ")");
+            fprintf(irp->f, ">");
         }
         fprintf(irp->f, " ");
     }
@@ -1165,6 +1165,12 @@ static void ir_print_save_err_ret_addr(IrPrint *irp, IrInstructionSaveErrRetAddr
     fprintf(irp->f, "@saveErrRetAddr()");
 }
 
+static void ir_print_add_implicit_return_type(IrPrint *irp, IrInstructionAddImplicitReturnType *instruction) {
+    fprintf(irp->f, "@addImplicitReturnType(");
+    ir_print_other_instruction(irp, instruction->value);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -1539,6 +1545,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdSaveErrRetAddr:
             ir_print_save_err_ret_addr(irp, (IrInstructionSaveErrRetAddr *)instruction);
             break;
+        case IrInstructionIdAddImplicitReturnType:
+            ir_print_add_implicit_return_type(irp, (IrInstructionAddImplicitReturnType *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/cases/coroutines.zig
@@ -176,3 +176,14 @@ async<&std.mem.Allocator> fn simpleAsyncFn2(y: &i32) void {
     *y += 1;
     suspend;
 }
+
+test "async fn with inferred error set" {
+    const p = (async<std.debug.global_allocator> failing()) catch unreachable;
+    resume p;
+    cancel p;
+}
+
+async fn failing() !void {
+    suspend;
+    return error.Fail;
+}