Commit 64c293f8a4

Andrew Kelley <andrew@ziglang.org>
2019-08-14 18:52:20
codegen for async call of blocking function
1 parent f3f838c
Changed files (2)
src/analyze.cpp
@@ -3831,7 +3831,7 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) {
 }
 
 // This function resolves functions being inferred async.
-static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
+static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
     if (fn->inferred_async_node == inferred_async_checking) {
         // TODO call graph cycle detected, disallow the recursion
         fn->inferred_async_node = inferred_async_none;
@@ -3841,7 +3841,9 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
         return;
     }
     if (fn->inferred_async_node != nullptr) {
-        resolve_async_fn_frame(g, fn);
+        if (resolve_frame) {
+            resolve_async_fn_frame(g, fn);
+        }
         return;
     }
     fn->inferred_async_node = inferred_async_checking;
@@ -3870,7 +3872,7 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
             }
         }
         assert(callee->anal_state == FnAnalStateComplete);
-        analyze_fn_async(g, callee);
+        analyze_fn_async(g, callee, true);
         if (callee->anal_state == FnAnalStateInvalid) {
             fn->anal_state = FnAnalStateInvalid;
             return;
@@ -3886,7 +3888,9 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
                 fn->anal_state = FnAnalStateInvalid;
                 return;
             }
-            resolve_async_fn_frame(g, fn);
+            if (resolve_frame) {
+                resolve_async_fn_frame(g, fn);
+            }
             return;
         }
     }
@@ -4141,7 +4145,7 @@ void semantic_analyze(CodeGen *g) {
     // second pass over functions for detecting async
     for (g->fn_defs_index = 0; g->fn_defs_index < g->fn_defs.length; g->fn_defs_index += 1) {
         ZigFn *fn_entry = g->fn_defs.at(g->fn_defs_index);
-        analyze_fn_async(g, fn_entry);
+        analyze_fn_async(g, fn_entry, true);
     }
 }
 
@@ -5212,6 +5216,36 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
             return ErrorSemanticAnalyzeFail;
         }
     }
+    analyze_fn_async(g, fn, false);
+    if (fn->anal_state == FnAnalStateInvalid)
+        return ErrorSemanticAnalyzeFail;
+
+    if (!fn_is_async(fn)) {
+        ZigType *fn_type = fn->type_entry;
+        FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
+        ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
+
+        // label (grep this): [fn_frame_struct_layout]
+        ZigList<SrcField> fields = {};
+
+        fields.append({"@fn_ptr", g->builtin_types.entry_usize, 0});
+        fields.append({"@resume_index", g->builtin_types.entry_usize, 0});
+        fields.append({"@awaiter", g->builtin_types.entry_usize, 0});
+        fields.append({"@prev_val", g->builtin_types.entry_usize, 0});
+
+        fields.append({"@result_ptr_callee", ptr_return_type, 0});
+        fields.append({"@result_ptr_awaiter", ptr_return_type, 0});
+        fields.append({"@result", fn_type_id->return_type, 0});
+
+        frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name),
+                fields.items, fields.length, target_fn_align(g->zig_target));
+        frame_type->abi_size = frame_type->data.frame.locals_struct->abi_size;
+        frame_type->abi_align = frame_type->data.frame.locals_struct->abi_align;
+        frame_type->size_in_bits = frame_type->data.frame.locals_struct->size_in_bits;
+
+        return ErrorNone;
+    }
+
     ZigType *fn_type = get_async_fn_type(g, fn->type_entry);
 
     if (fn->analyzed_executable.need_err_code_spill) {
@@ -5252,7 +5286,7 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
             frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid;
             return ErrorSemanticAnalyzeFail;
         }
-        analyze_fn_async(g, callee);
+        analyze_fn_async(g, callee, true);
         if (!fn_is_async(callee))
             continue;
 
@@ -5268,6 +5302,8 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
         fn->alloca_gen_list.append(alloca_gen);
         call->frame_result_loc = &alloca_gen->base;
     }
+    FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
+    ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
 
     // label (grep this): [fn_frame_struct_layout]
     ZigList<SrcField> fields = {};
@@ -5277,9 +5313,6 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
     fields.append({"@awaiter", g->builtin_types.entry_usize, 0});
     fields.append({"@prev_val", g->builtin_types.entry_usize, 0});
 
-    FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
-    ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
-
     fields.append({"@result_ptr_callee", ptr_return_type, 0});
     fields.append({"@result_ptr_awaiter", ptr_return_type, 0});
     fields.append({"@result", fn_type_id->return_type, 0});
@@ -7651,7 +7684,8 @@ static void resolve_llvm_types_anyerror(CodeGen *g) {
 }
 
 static void resolve_llvm_types_async_frame(CodeGen *g, ZigType *frame_type, ResolveStatus wanted_resolve_status) {
-    resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status, frame_type);
+    ZigType *passed_frame_type = fn_is_async(frame_type->data.frame.fn) ? frame_type : nullptr;
+    resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status, passed_frame_type);
     frame_type->llvm_type = frame_type->data.frame.locals_struct->llvm_type;
     frame_type->llvm_di_type = frame_type->data.frame.locals_struct->llvm_di_type;
 }
src/codegen.cpp
@@ -3850,73 +3850,74 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
     LLVMValueRef frame_result_loc;
     LLVMValueRef awaiter_init_val;
     LLVMValueRef ret_ptr;
-    if (instruction->is_async) {
-        awaiter_init_val = zero;
-
-        if (instruction->new_stack == nullptr) {
-            frame_result_loc = result_loc;
-
-            if (ret_has_bits) {
-                // Use the result location which is inside the frame if this is an async call.
-                ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
-            }
-        } else {
-            LLVMValueRef frame_slice_ptr = ir_llvm_value(g, instruction->new_stack);
-            if (ir_want_runtime_safety(g, &instruction->base)) {
-                LLVMValueRef given_len_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_len_index, "");
-                LLVMValueRef given_frame_len = LLVMBuildLoad(g->builder, given_len_ptr, "");
-                LLVMValueRef actual_frame_len = gen_frame_size(g, fn_val);
-
-                LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "FrameSizeCheckFail");
-                LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "FrameSizeCheckOk");
-
-                LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntUGE, given_frame_len, actual_frame_len, "");
-                LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+    if (callee_is_async) {
+        if (instruction->is_async) {
+            if (instruction->new_stack == nullptr) {
+                awaiter_init_val = zero;
+                frame_result_loc = result_loc;
+
+                if (ret_has_bits) {
+                    // Use the result location which is inside the frame if this is an async call.
+                    ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
+                }
+            } else if (cc == CallingConventionAsync) {
+                awaiter_init_val = zero;
+                LLVMValueRef frame_slice_ptr = ir_llvm_value(g, instruction->new_stack);
+                if (ir_want_runtime_safety(g, &instruction->base)) {
+                    LLVMValueRef given_len_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_len_index, "");
+                    LLVMValueRef given_frame_len = LLVMBuildLoad(g->builder, given_len_ptr, "");
+                    LLVMValueRef actual_frame_len = gen_frame_size(g, fn_val);
+
+                    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "FrameSizeCheckFail");
+                    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "FrameSizeCheckOk");
+
+                    LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntUGE, given_frame_len, actual_frame_len, "");
+                    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
 
-                LLVMPositionBuilderAtEnd(g->builder, fail_block);
-                gen_safety_crash(g, PanicMsgIdFrameTooSmall);
+                    LLVMPositionBuilderAtEnd(g->builder, fail_block);
+                    gen_safety_crash(g, PanicMsgIdFrameTooSmall);
 
-                LLVMPositionBuilderAtEnd(g->builder, ok_block);
+                    LLVMPositionBuilderAtEnd(g->builder, ok_block);
+                }
+                LLVMValueRef frame_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_ptr_index, "");
+                LLVMValueRef frame_ptr = LLVMBuildLoad(g->builder, frame_ptr_ptr, "");
+                frame_result_loc = LLVMBuildBitCast(g->builder, frame_ptr,
+                        get_llvm_type(g, instruction->base.value.type), "");
+
+                if (ret_has_bits) {
+                    // Use the result location provided to the @asyncCall builtin
+                    ret_ptr = result_loc;
+                }
             }
-            LLVMValueRef frame_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_slice_ptr, slice_ptr_index, "");
-            LLVMValueRef frame_ptr = LLVMBuildLoad(g->builder, frame_ptr_ptr, "");
-            frame_result_loc = LLVMBuildBitCast(g->builder, frame_ptr,
-                    get_llvm_type(g, instruction->base.value.type), "");
 
+            // even if prefix_arg_err_ret_stack is true, let the async function do its own
+            // initialization.
+        } else {
+            frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc);
+            awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); // caller's own frame pointer
             if (ret_has_bits) {
-                // Use the result location provided to the @asyncCall builtin
-                ret_ptr = result_loc;
-            }
-        }
+                if (result_loc == nullptr) {
+                    // return type is a scalar, but we still need a pointer to it. Use the async fn frame.
+                    ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
+                } else {
+                    // Use the call instruction's result location.
+                    ret_ptr = result_loc;
+                }
 
-        // even if prefix_arg_err_ret_stack is true, let the async function do its own
-        // initialization.
-    } else if (callee_is_async) {
-        frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc);
-        awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); // caller's own frame pointer
-        if (ret_has_bits) {
-            if (result_loc == nullptr) {
-                // return type is a scalar, but we still need a pointer to it. Use the async fn frame.
-                ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
-            } else {
-                // Use the call instruction's result location.
-                ret_ptr = result_loc;
+                // Store a zero in the awaiter's result ptr to indicate we do not need a copy made.
+                LLVMValueRef awaiter_ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 1, "");
+                LLVMValueRef zero_ptr = LLVMConstNull(LLVMGetElementType(LLVMTypeOf(awaiter_ret_ptr)));
+                LLVMBuildStore(g->builder, zero_ptr, awaiter_ret_ptr);
             }
 
-            // Store a zero in the awaiter's result ptr to indicate we do not need a copy made.
-            LLVMValueRef awaiter_ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 1, "");
-            LLVMValueRef zero_ptr = LLVMConstNull(LLVMGetElementType(LLVMTypeOf(awaiter_ret_ptr)));
-            LLVMBuildStore(g->builder, zero_ptr, awaiter_ret_ptr);
+            if (prefix_arg_err_ret_stack) {
+                LLVMValueRef err_ret_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
+                        frame_index_trace_arg(g, src_return_type), "");
+                LLVMValueRef my_err_ret_trace_val = get_cur_err_ret_trace_val(g, instruction->base.scope);
+                LLVMBuildStore(g->builder, my_err_ret_trace_val, err_ret_trace_ptr_ptr);
+            }
         }
 
-        if (prefix_arg_err_ret_stack) {
-            LLVMValueRef err_ret_trace_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc,
-                    frame_index_trace_arg(g, src_return_type), "");
-            LLVMValueRef my_err_ret_trace_val = get_cur_err_ret_trace_val(g, instruction->base.scope);
-            LLVMBuildStore(g->builder, my_err_ret_trace_val, err_ret_trace_ptr_ptr);
-        }
-    }
-    if (instruction->is_async || callee_is_async) {
         assert(frame_result_loc != nullptr);
 
         LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_fn_ptr_index, "");
@@ -3934,6 +3935,29 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
             LLVMValueRef ret_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start, "");
             LLVMBuildStore(g->builder, ret_ptr, ret_ptr_ptr);
         }
+    } else if (instruction->is_async) {
+        // Async call of blocking function
+        if (instruction->new_stack != nullptr) {
+            zig_panic("TODO @asyncCall of non-async function");
+        }
+        frame_result_loc = result_loc;
+        awaiter_init_val = LLVMConstAllOnes(usize_type_ref);
+
+        LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_awaiter_index, "");
+        LLVMBuildStore(g->builder, awaiter_init_val, awaiter_ptr);
+
+        if (ret_has_bits) {
+            LLVMValueRef ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
+            LLVMValueRef ret_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start, "");
+            LLVMBuildStore(g->builder, ret_ptr, ret_ptr_ptr);
+
+            if (first_arg_ret) {
+                gen_param_values.append(ret_ptr);
+            }
+        }
+        if (prefix_arg_err_ret_stack) {
+            gen_param_values.append(get_cur_err_ret_trace_val(g, instruction->base.scope));
+        }
     } else {
         if (first_arg_ret) {
             gen_param_values.append(result_loc);
@@ -3966,7 +3990,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
     LLVMCallConv llvm_cc = get_llvm_cc(g, cc);
     LLVMValueRef result;
 
-    if (instruction->is_async || callee_is_async) {
+    if (callee_is_async) {
         uint32_t arg_start_i = frame_index_arg(g, fn_type->data.fn.fn_type_id.return_type);
 
         LLVMValueRef casted_frame;
@@ -3992,39 +4016,42 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
             gen_assign_raw(g, arg_ptr, get_pointer_to_type(g, gen_param_types.at(arg_i), true),
                     gen_param_values.at(arg_i));
         }
-    }
-    if (instruction->is_async) {
-        gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
-        if (instruction->new_stack != nullptr) {
-            return frame_result_loc;
-        }
-        return nullptr;
-    } else if (callee_is_async) {
-        ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true);
 
-        LLVMBasicBlockRef call_bb = gen_suspend_begin(g, "CallResume");
+        if (instruction->is_async) {
+            gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
+            if (instruction->new_stack != nullptr) {
+                return frame_result_loc;
+            }
+            return nullptr;
+        } else {
+            ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true);
 
-        LLVMValueRef call_inst = gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
-        set_tail_call_if_appropriate(g, call_inst);
-        LLVMBuildRetVoid(g->builder);
+            LLVMBasicBlockRef call_bb = gen_suspend_begin(g, "CallResume");
+
+            LLVMValueRef call_inst = gen_resume(g, fn_val, frame_result_loc, ResumeIdCall, nullptr);
+            set_tail_call_if_appropriate(g, call_inst);
+            LLVMBuildRetVoid(g->builder);
 
-        LLVMPositionBuilderAtEnd(g->builder, call_bb);
-        gen_assert_resume_id(g, &instruction->base, ResumeIdReturn, PanicMsgIdResumedAnAwaitingFn, nullptr);
-        render_async_var_decls(g, instruction->base.scope);
+            LLVMPositionBuilderAtEnd(g->builder, call_bb);
+            gen_assert_resume_id(g, &instruction->base, ResumeIdReturn, PanicMsgIdResumedAnAwaitingFn, nullptr);
+            render_async_var_decls(g, instruction->base.scope);
 
-        if (!type_has_bits(src_return_type))
-            return nullptr;
+            if (!type_has_bits(src_return_type))
+                return nullptr;
 
-        if (result_loc != nullptr) 
-            return get_handle_value(g, result_loc, src_return_type, ptr_result_type);
+            if (result_loc != nullptr) 
+                return get_handle_value(g, result_loc, src_return_type, ptr_result_type);
 
-        LLVMValueRef result_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
-        return LLVMBuildLoad(g->builder, result_ptr, "");
+            LLVMValueRef result_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, frame_ret_start + 2, "");
+            return LLVMBuildLoad(g->builder, result_ptr, "");
+        }
     }
 
     if (instruction->new_stack == nullptr) {
         result = ZigLLVMBuildCall(g->builder, fn_val,
                 gen_param_values.items, (unsigned)gen_param_values.length, llvm_cc, fn_inline, "");
+    } else if (instruction->is_async) {
+        zig_panic("TODO @asyncCall of non-async function");
     } else {
         LLVMValueRef stacksave_fn_val = get_stacksave_fn_val(g);
         LLVMValueRef stackrestore_fn_val = get_stackrestore_fn_val(g);