Commit 03910925f0

Andrew Kelley <andrew@ziglang.org>
2019-08-30 03:51:31
await does not force async if callee is blocking
closes #3067
1 parent 8e93991
src/all_types.hpp
@@ -36,6 +36,7 @@ struct IrInstruction;
 struct IrInstructionCast;
 struct IrInstructionAllocaGen;
 struct IrInstructionCallGen;
+struct IrInstructionAwaitGen;
 struct IrBasicBlock;
 struct ScopeDecls;
 struct ZigWindowsSDK;
@@ -1486,6 +1487,7 @@ struct ZigFn {
     AstNode **param_source_nodes;
     Buf **param_names;
     IrInstruction *err_code_spill;
+    AstNode *assumed_non_async;
 
     AstNode *fn_no_inline_set_node;
     AstNode *fn_static_eval_set_node;
@@ -1503,6 +1505,7 @@ struct ZigFn {
 
     ZigList<GlobalExport> export_list;
     ZigList<IrInstructionCallGen *> call_list;
+    ZigList<IrInstructionAwaitGen *> await_list;
 
     LLVMValueRef valgrind_client_request_array;
 
@@ -3717,6 +3720,7 @@ struct IrInstructionAwaitGen {
 
     IrInstruction *frame;
     IrInstruction *result_loc;
+    ZigFn *target_fn;
 };
 
 struct IrInstructionResume {
src/analyze.cpp
@@ -31,6 +31,7 @@ static void analyze_fn_body(CodeGen *g, ZigFn *fn_table_entry);
 static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_resolve_status);
 static void preview_use_decl(CodeGen *g, TldUsingNamespace *using_namespace, ScopeDecls *dest_decls_scope);
 static void resolve_use_decl(CodeGen *g, TldUsingNamespace *tld_using_namespace, ScopeDecls *dest_decls_scope);
+static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame);
 
 // nullptr means not analyzed yet; this one means currently being analyzed
 static const AstNode *inferred_async_checking = reinterpret_cast<AstNode *>(0x1);
@@ -4196,6 +4197,54 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
     }
 }
 
+// Decides whether calling (or awaiting) callee at call_node makes fn async. Returns ErrorNone when callee is not async, or is only assumed non-async because its analysis is not complete yet;
+// ErrorIsAsync when callee is async (fn->inferred_async_node and fn->inferred_async_fn are then set to call_node and callee);
+// ErrorSemanticAnalyzeFail when a compile error was emitted and the result is invalid.
+static Error analyze_callee_async(CodeGen *g, ZigFn *fn, ZigFn *callee, AstNode *call_node,
+        bool must_not_be_async)
+{
+    if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
+        return ErrorNone; // any calling convention other than Unspecified is reported as not async
+    if (callee->anal_state == FnAnalStateReady) { // callee is still in the Ready state: analyze its body first
+        analyze_fn_body(g, callee);
+        if (callee->anal_state == FnAnalStateInvalid) {
+            return ErrorSemanticAnalyzeFail;
+        }
+    }
+    bool callee_is_async;
+    if (callee->anal_state == FnAnalStateComplete) { // analysis is complete: resolve the callee's async-ness now (resolve_frame = true)
+        analyze_fn_async(g, callee, true);
+        if (callee->anal_state == FnAnalStateInvalid) {
+            return ErrorSemanticAnalyzeFail;
+        }
+        callee_is_async = fn_is_async(callee);
+    } else {
+        // If the callee's async-ness has already been determined, use that value. Otherwise assume non-async and
+        // record call_node in callee->assumed_non_async; an error is emitted later if it turned out to be async.
+        if (callee->inferred_async_node == nullptr ||
+            callee->inferred_async_node == inferred_async_checking)
+        {
+            callee->assumed_non_async = call_node;
+            callee_is_async = false;
+        } else {
+            callee_is_async = callee->inferred_async_node != inferred_async_none;
+        }
+    }
+    if (callee_is_async) {
+        fn->inferred_async_node = call_node; // remember which node and which callee made fn async
+        fn->inferred_async_fn = callee;
+        if (must_not_be_async) { // caller says fn is not allowed to be async: report it against fn's prototype
+            ErrorMsg *msg = add_node_error(g, fn->proto_node,
+                buf_sprintf("function with calling convention '%s' cannot be async",
+                    calling_convention_name(fn->type_entry->data.fn.fn_type_id.cc)));
+            add_async_error_notes(g, msg, fn);
+            return ErrorSemanticAnalyzeFail;
+        }
+        return ErrorIsAsync;
+    }
+    return ErrorNone;
+}
+
 // This function resolves functions being inferred async.
 static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
     if (fn->inferred_async_node == inferred_async_checking) {
@@ -4222,47 +4271,40 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
 
     for (size_t i = 0; i < fn->call_list.length; i += 1) {
         IrInstructionCallGen *call = fn->call_list.at(i);
-        ZigFn *callee = call->fn_entry;
-        if (callee == nullptr) {
+        if (call->fn_entry == nullptr) {
             // TODO function pointer call here, could be anything
             continue;
         }
-
-        if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
-            continue;
-        if (callee->anal_state == FnAnalStateReady) {
-            analyze_fn_body(g, callee);
-            if (callee->anal_state == FnAnalStateInvalid) {
+        switch (analyze_callee_async(g, fn, call->fn_entry, call->base.source_node, must_not_be_async)) {
+            case ErrorSemanticAnalyzeFail:
                 fn->anal_state = FnAnalStateInvalid;
                 return;
-            }
-        }
-        if (callee->anal_state != FnAnalStateComplete) {
-            add_node_error(g, call->base.source_node,
-                buf_sprintf("call to function '%s' depends on itself", buf_ptr(&callee->symbol_name)));
-            fn->anal_state = FnAnalStateInvalid;
-            return;
-        }
-        analyze_fn_async(g, callee, true);
-        if (callee->anal_state == FnAnalStateInvalid) {
-            fn->anal_state = FnAnalStateInvalid;
-            return;
+            case ErrorNone:
+                continue;
+            case ErrorIsAsync:
+                if (resolve_frame) {
+                    resolve_async_fn_frame(g, fn);
+                }
+                return;
+            default:
+                zig_unreachable();
         }
-        if (fn_is_async(callee)) {
-            fn->inferred_async_node = call->base.source_node;
-            fn->inferred_async_fn = callee;
-            if (must_not_be_async) {
-                ErrorMsg *msg = add_node_error(g, fn->proto_node,
-                    buf_sprintf("function with calling convention '%s' cannot be async",
-                        calling_convention_name(fn->type_entry->data.fn.fn_type_id.cc)));
-                add_async_error_notes(g, msg, fn);
+    }
+    for (size_t i = 0; i < fn->await_list.length; i += 1) {
+        IrInstructionAwaitGen *await = fn->await_list.at(i);
+        switch (analyze_callee_async(g, fn, await->target_fn, await->base.source_node, must_not_be_async)) {
+            case ErrorSemanticAnalyzeFail:
                 fn->anal_state = FnAnalStateInvalid;
                 return;
-            }
-            if (resolve_frame) {
-                resolve_async_fn_frame(g, fn);
-            }
-            return;
+            case ErrorNone:
+                continue;
+            case ErrorIsAsync:
+                if (resolve_frame) {
+                    resolve_async_fn_frame(g, fn);
+                }
+                return;
+            default:
+                zig_unreachable();
         }
     }
     fn->inferred_async_node = inferred_async_none;
src/codegen.cpp
@@ -3924,7 +3924,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         LLVMBuildStore(g->builder, awaiter_init_val, awaiter_ptr);
 
         if (ret_has_bits) {
-            LLVMValueRef ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
+            ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
             LLVMValueRef ret_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start, "");
             LLVMBuildStore(g->builder, ret_ptr, ret_ptr_ptr);
 
@@ -4067,6 +4067,9 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         LLVMValueRef store_instr = LLVMBuildStore(g->builder, result, result_loc);
         LLVMSetAlignment(store_instr, get_ptr_align(g, instruction->result_loc->value.type));
         return result_loc;
+    } else if (!callee_is_async && instruction->is_async) {
+        LLVMBuildStore(g->builder, result, ret_ptr);
+        return result_loc;
     } else {
         return result;
     }
@@ -5498,6 +5501,44 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl
     return nullptr;
 }
 
+static LLVMValueRef gen_await_early_return(CodeGen *g, IrInstruction *source_instr, // Emits the await path that needs no suspension (target already finished, or target is not async):
+        LLVMValueRef target_frame_ptr, ZigType *result_type, ZigType *ptr_result_type, // copies the target's result into result_loc when one is given and merges error return traces when applicable.
+        LLVMValueRef result_loc, bool non_async) // Returns a handle to the result when non_async is set and result_type has bits; otherwise nullptr.
+{
+    LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
+    LLVMValueRef their_result_ptr = nullptr;
+    if (type_has_bits(result_type) && (non_async || result_loc != nullptr)) { // the target's result pointer is needed either for the copy or for the returned handle
+        LLVMValueRef their_result_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start, "");
+        their_result_ptr = LLVMBuildLoad(g->builder, their_result_ptr_ptr, "");
+        if (result_loc != nullptr) { // copy type_size(result_type) bytes from the target's result into result_loc
+            LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
+            LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, result_loc, ptr_u8, "");
+            LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, their_result_ptr, ptr_u8, "");
+            bool is_volatile = false;
+            uint32_t abi_align = get_abi_alignment(g, result_type);
+            LLVMValueRef byte_count_val = LLVMConstInt(usize_type_ref, type_size(g, result_type), false);
+            ZigLLVMBuildMemCpy(g->builder,
+                    dest_ptr_casted, abi_align,
+                    src_ptr_casted, abi_align, byte_count_val, is_volatile);
+        }
+    }
+    if (codegen_fn_has_err_ret_tracing_arg(g, result_type)) { // merge the trace stored in the target frame into the current error return trace
+        LLVMValueRef their_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr,
+                frame_index_trace_arg(g, result_type), "");
+        LLVMValueRef src_trace_ptr = LLVMBuildLoad(g->builder, their_trace_ptr_ptr, "");
+        LLVMValueRef dest_trace_ptr = get_cur_err_ret_trace_val(g, source_instr->scope);
+        LLVMValueRef args[] = { dest_trace_ptr, src_trace_ptr };
+        ZigLLVMBuildCall(g->builder, get_merge_err_ret_traces_fn_val(g), args, 2,
+                get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
+    }
+    if (non_async && type_has_bits(result_type)) {
+        LLVMValueRef result_ptr = (result_loc == nullptr) ? their_result_ptr : result_loc; // without a result_loc, hand back the pointer loaded from the target frame
+        return get_handle_value(g, result_ptr, result_type, ptr_result_type);
+    } else {
+        return nullptr;
+    }
+}
+
 static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInstructionAwaitGen *instruction) {
     LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
     LLVMValueRef zero = LLVMConstNull(usize_type_ref);
@@ -5505,6 +5546,14 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     ZigType *result_type = instruction->base.value.type;
     ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true);
 
+    LLVMValueRef result_loc = (instruction->result_loc == nullptr) ?
+        nullptr : ir_llvm_value(g, instruction->result_loc);
+
+    if (instruction->target_fn != nullptr && !fn_is_async(instruction->target_fn)) {
+        return gen_await_early_return(g, &instruction->base, target_frame_ptr, result_type,
+                ptr_result_type, result_loc, true);
+    }
+
     // Prepare to be suspended
     LLVMBasicBlockRef resume_bb = gen_suspend_begin(g, "AwaitResume");
     LLVMBasicBlockRef end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "AwaitEnd");
@@ -5512,9 +5561,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     // At this point resuming the function will continue from resume_bb.
     // This code is as if it is running inside the suspend block.
 
+
     // supply the awaiter return pointer
-    LLVMValueRef result_loc = (instruction->result_loc == nullptr) ?
-        nullptr : ir_llvm_value(g, instruction->result_loc);
     if (type_has_bits(result_type)) {
         LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start + 1, "");
         if (result_loc == nullptr) {
@@ -5562,28 +5610,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     // Early return: The async function has already completed. We must copy the result and
     // the error return trace if applicable.
     LLVMPositionBuilderAtEnd(g->builder, early_return_block);
-    if (type_has_bits(result_type) && result_loc != nullptr) {
-        LLVMValueRef their_result_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start, "");
-        LLVMValueRef their_result_ptr = LLVMBuildLoad(g->builder, their_result_ptr_ptr, "");
-        LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
-        LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, result_loc, ptr_u8, "");
-        LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, their_result_ptr, ptr_u8, "");
-        bool is_volatile = false;
-        uint32_t abi_align = get_abi_alignment(g, result_type);
-        LLVMValueRef byte_count_val = LLVMConstInt(usize_type_ref, type_size(g, result_type), false);
-        ZigLLVMBuildMemCpy(g->builder,
-                dest_ptr_casted, abi_align,
-                src_ptr_casted, abi_align, byte_count_val, is_volatile);
-    }
-    if (codegen_fn_has_err_ret_tracing_arg(g, result_type)) {
-        LLVMValueRef their_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr,
-                frame_index_trace_arg(g, result_type), "");
-        LLVMValueRef src_trace_ptr = LLVMBuildLoad(g->builder, their_trace_ptr_ptr, "");
-        LLVMValueRef dest_trace_ptr = get_cur_err_ret_trace_val(g, instruction->base.scope);
-        LLVMValueRef args[] = { dest_trace_ptr, src_trace_ptr };
-        ZigLLVMBuildCall(g->builder, get_merge_err_ret_traces_fn_val(g), args, 2,
-                get_llvm_cc(g, CallingConventionUnspecified), ZigLLVM_FnInlineAuto, "");
-    }
+    gen_await_early_return(g, &instruction->base, target_frame_ptr, result_type, ptr_result_type,
+            result_loc, false);
     LLVMBuildBr(g->builder, end_bb);
 
     LLVMPositionBuilderAtEnd(g->builder, resume_bb);
src/error.cpp
@@ -56,6 +56,7 @@ const char *err_str(Error err) {
         case ErrorNoSpaceLeft: return "no space left";
         case ErrorNoCCompilerInstalled: return "no C compiler installed";
         case ErrorNotLazy: return "not lazy";
+        case ErrorIsAsync: return "is async";
     }
     return "(invalid error)";
 }
src/ir.cpp
@@ -3268,7 +3268,7 @@ static IrInstruction *ir_build_await_src(IrBuilder *irb, Scope *scope, AstNode *
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_await_gen(IrAnalyze *ira, IrInstruction *source_instruction,
+static IrInstructionAwaitGen *ir_build_await_gen(IrAnalyze *ira, IrInstruction *source_instruction,
         IrInstruction *frame, ZigType *result_type, IrInstruction *result_loc)
 {
     IrInstructionAwaitGen *instruction = ir_build_instruction<IrInstructionAwaitGen>(&ira->new_irb,
@@ -3280,7 +3280,7 @@ static IrInstruction *ir_build_await_gen(IrAnalyze *ira, IrInstruction *source_i
     ir_ref_instruction(frame, ira->new_irb.current_basic_block);
     if (result_loc != nullptr) ir_ref_instruction(result_loc, ira->new_irb.current_basic_block);
 
-    return &instruction->base;
+    return instruction;
 }
 
 static IrInstruction *ir_build_resume(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame) {
@@ -24763,18 +24763,22 @@ static IrInstruction *ir_analyze_instruction_suspend_finish(IrAnalyze *ira,
 }
 
 static IrInstruction *analyze_frame_ptr_to_anyframe_T(IrAnalyze *ira, IrInstruction *source_instr,
-        IrInstruction *frame_ptr)
+        IrInstruction *frame_ptr, ZigFn **target_fn)
 {
     if (type_is_invalid(frame_ptr->value.type))
         return ira->codegen->invalid_instruction;
 
+    *target_fn = nullptr;
+
     ZigType *result_type;
     IrInstruction *frame;
     if (frame_ptr->value.type->id == ZigTypeIdPointer &&
         frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
         frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdFnFrame)
     {
-        result_type = frame_ptr->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type;
+        ZigFn *func = frame_ptr->value.type->data.pointer.child_type->data.frame.fn;
+        result_type = func->type_entry->data.fn.fn_type_id.return_type;
+        *target_fn = func;
         frame = frame_ptr;
     } else {
         frame = ir_get_deref(ira, source_instr, frame_ptr, nullptr);
@@ -24782,7 +24786,9 @@ static IrInstruction *analyze_frame_ptr_to_anyframe_T(IrAnalyze *ira, IrInstruct
             frame->value.type->data.pointer.ptr_len == PtrLenSingle &&
             frame->value.type->data.pointer.child_type->id == ZigTypeIdFnFrame)
         {
-            result_type = frame->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type;
+            ZigFn *func = frame->value.type->data.pointer.child_type->data.frame.fn;
+            result_type = func->type_entry->data.fn.fn_type_id.return_type;
+            *target_fn = func;
         } else if (frame->value.type->id != ZigTypeIdAnyFrame ||
             frame->value.type->data.any_frame.result_type == nullptr)
         {
@@ -24803,7 +24809,11 @@ static IrInstruction *analyze_frame_ptr_to_anyframe_T(IrAnalyze *ira, IrInstruct
 }
 
 static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwaitSrc *instruction) {
-    IrInstruction *frame = analyze_frame_ptr_to_anyframe_T(ira, &instruction->base, instruction->frame->child);
+    IrInstruction *operand = instruction->frame->child;
+    if (type_is_invalid(operand->value.type))
+        return ira->codegen->invalid_instruction;
+    ZigFn *target_fn;
+    IrInstruction *frame = analyze_frame_ptr_to_anyframe_T(ira, &instruction->base, operand, &target_fn);
     if (type_is_invalid(frame->value.type))
         return ira->codegen->invalid_instruction;
 
@@ -24812,8 +24822,11 @@ static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstruction
     ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec);
     ir_assert(fn_entry != nullptr, &instruction->base);
 
-    if (fn_entry->inferred_async_node == nullptr) {
-        fn_entry->inferred_async_node = instruction->base.source_node;
+    // If it's not @Frame(func) then it's definitely a suspend point
+    if (target_fn == nullptr) {
+        if (fn_entry->inferred_async_node == nullptr) {
+            fn_entry->inferred_async_node = instruction->base.source_node;
+        }
     }
 
     if (type_can_fail(result_type)) {
@@ -24830,8 +24843,10 @@ static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstruction
         result_loc = nullptr;
     }
 
-    IrInstruction *result = ir_build_await_gen(ira, &instruction->base, frame, result_type, result_loc);
-    return ir_finish_anal(ira, result);
+    IrInstructionAwaitGen *result = ir_build_await_gen(ira, &instruction->base, frame, result_type, result_loc);
+    result->target_fn = target_fn;
+    fn_entry->await_list.append(result);
+    return ir_finish_anal(ira, &result->base);
 }
 
 static IrInstruction *ir_analyze_instruction_resume(IrAnalyze *ira, IrInstructionResume *instruction) {
src/userland.h
@@ -76,6 +76,7 @@ enum Error {
     ErrorBrokenPipe,
     ErrorNoSpaceLeft,
     ErrorNotLazy,
+    ErrorIsAsync,
 };
 
 // ABI warning
test/stage1/behavior/async_fn.zig
@@ -844,3 +844,13 @@ test "cast fn to async fn when it is inferred to be async" {
     resume S.frame;
     expect(S.ok);
 }
+
+test "await does not force async if callee is blocking" {
+    const S = struct {
+        fn simple() i32 { // contains no suspend point: it only returns a constant
+            return 1234;
+        }
+    };
+    var x = async S.simple(); // start simple() with `async`, keeping its frame in x
+    expect(await x == 1234); // awaiting that frame must yield simple()'s return value
+}