Commit 7e9760de10

Andrew Kelley <andrew@ziglang.org>
2019-07-24 00:54:45
inferring async from async calls
1 parent 317d1ec
src/all_types.hpp
@@ -1336,6 +1336,11 @@ struct GlobalExport {
     GlobalLinkageId linkage;
 };
 
+struct FnCall { // One non-async call site recorded in a ZigFn's call_list.
+    AstNode *source_node; // source node of the call instruction
+    ZigFn *callee; // the function being called
+};
+
 struct ZigFn {
     CodeGen *codegen;
     LLVMValueRef llvm_value;
@@ -1379,8 +1384,10 @@ struct ZigFn {
     AstNode *set_alignstack_node;
 
     AstNode *set_cold_node;
+    const AstNode *inferred_async_node;
 
     ZigList<GlobalExport> export_list;
+    ZigList<FnCall> call_list;
 
     LLVMValueRef valgrind_client_request_array;
     LLVMBasicBlockRef preamble_llvm_block;
src/analyze.cpp
@@ -31,6 +31,11 @@ static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_r
 static void preview_use_decl(CodeGen *g, TldUsingNamespace *using_namespace, ScopeDecls *dest_decls_scope);
 static void resolve_use_decl(CodeGen *g, TldUsingNamespace *tld_using_namespace, ScopeDecls *dest_decls_scope);
 
+// Sentinels for ZigFn::inferred_async_node. nullptr means not analyzed yet; this one means currently being analyzed.
+static const AstNode *inferred_async_checking = reinterpret_cast<AstNode *>(0x1);
+// This one means analyzed and it's not async. Any other value is the node that made the function async.
+static const AstNode *inferred_async_none = reinterpret_cast<AstNode *>(0x2);
+
 static bool is_top_level_struct(ZigType *import) { // true when `import` is a struct type that has a root_struct
     return import->id == ZigTypeIdStruct && import->data.structure.root_struct != nullptr;
 }
@@ -1892,8 +1897,12 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
     field_names.append("resume_index");
     field_types.append(g->builtin_types.entry_usize);
 
-    for (size_t arg_i = 0; arg_i < fn->type_entry->data.fn.fn_type_id.param_count; arg_i += 1) {
-        FnTypeParamInfo *param_info = &fn->type_entry->data.fn.fn_type_id.param_info[arg_i];
+    FnTypeId *fn_type_id = &fn->type_entry->data.fn.fn_type_id;
+    field_names.append("result");
+    field_types.append(fn_type_id->return_type);
+
+    for (size_t arg_i = 0; arg_i < fn_type_id->param_count; arg_i += 1) {
+        FnTypeParamInfo *param_info = &fn_type_id->param_info[arg_i];
         AstNode *param_decl_node = get_param_decl_node(fn, arg_i);
         Buf *param_name;
         bool is_var_args = param_decl_node && param_decl_node->data.param_decl.is_var_args;
@@ -2796,6 +2805,16 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
                 g->fn_defs.append(fn_table_entry);
         }
 
+        switch (fn_table_entry->type_entry->data.fn.fn_type_id.cc) {
+            case CallingConventionAsync:
+                fn_table_entry->inferred_async_node = fn_table_entry->proto_node;
+                break;
+            case CallingConventionUnspecified:
+                break;
+            default:
+                fn_table_entry->inferred_async_node = inferred_async_none;
+        }
+
         if (scope_is_root_decls(tld_fn->base.parent_scope) &&
             (import == g->root_import || import->data.structure.root_struct->package == g->panic_package))
         {
@@ -3767,6 +3786,55 @@ bool resolve_inferred_error_set(CodeGen *g, ZigType *err_set_type, AstNode *sour
     return true;
 }
 
+static void resolve_async_fn_frame(CodeGen *g, ZigFn *fn) { // Resolves fn's coroutine frame type to ResolveStatusSizeKnown.
+    ZigType *frame_type = get_coro_frame_type(g, fn);
+    Error err;
+    if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
+        fn->anal_state = FnAnalStateInvalid; // resolution failed; callers detect this through anal_state
+        return;
+    }
+}
+
+bool fn_is_async(ZigFn *fn) { // Reports whether fn has been determined to be async.
+    assert(fn->inferred_async_node != nullptr); // fn must already have been analyzed for async-ness
+    assert(fn->inferred_async_node != inferred_async_checking); // and that analysis must not be in progress
+    return fn->inferred_async_node != inferred_async_none; // any non-sentinel node means async
+}
+
+// Decides whether fn is async by examining the calls recorded in its call_list; when it is, also resolves its frame type.
+static void analyze_fn_async(CodeGen *g, ZigFn *fn) {
+    if (fn->inferred_async_node == inferred_async_checking) {
+        // TODO call graph cycle detected, disallow the recursion
+        fn->inferred_async_node = inferred_async_none; // for now a function reached through a cycle is marked not async
+        return;
+    }
+    if (fn->inferred_async_node == inferred_async_none) {
+        return; // already known to be non-async
+    }
+    if (fn->inferred_async_node != nullptr) {
+        resolve_async_fn_frame(g, fn); // already known to be async; only the frame remains to be resolved
+        return;
+    }
+    fn->inferred_async_node = inferred_async_checking; // mark in progress so that recursing back into fn is detected
+    for (size_t i = 0; i < fn->call_list.length; i += 1) {
+        FnCall *call = &fn->call_list.at(i);
+        if (call->callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
+            continue; // only callees with an unspecified calling convention are inferred here
+        assert(call->callee->anal_state == FnAnalStateComplete);
+        analyze_fn_async(g, call->callee);
+        if (call->callee->anal_state == FnAnalStateInvalid) {
+            fn->anal_state = FnAnalStateInvalid; // propagate the callee's failure to fn
+            return;
+        }
+        if (fn_is_async(call->callee)) {
+            fn->inferred_async_node = call->source_node; // remember the call that makes fn async
+            resolve_async_fn_frame(g, fn);
+            return;
+        }
+    }
+    fn->inferred_async_node = inferred_async_none; // no async callee was found
+}
+
 static void analyze_fn_ir(CodeGen *g, ZigFn *fn_table_entry, AstNode *return_type_node) {
     ZigType *fn_type = fn_table_entry->type_entry;
     assert(!fn_type->data.fn.is_generic);
@@ -3824,17 +3892,7 @@ static void analyze_fn_ir(CodeGen *g, ZigFn *fn_table_entry, AstNode *return_typ
         ir_print(g, stderr, &fn_table_entry->analyzed_executable, 4);
         fprintf(stderr, "}\n");
     }
-
     fn_table_entry->anal_state = FnAnalStateComplete;
-
-    if (fn_table_entry->resume_blocks.length != 0) {
-        ZigType *frame_type = get_coro_frame_type(g, fn_table_entry);
-        Error err;
-        if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
-            fn_table_entry->anal_state = FnAnalStateInvalid;
-            return;
-        }
-    }
 }
 
 static void analyze_fn_body(CodeGen *g, ZigFn *fn_table_entry) {
@@ -4004,6 +4062,16 @@ void semantic_analyze(CodeGen *g) {
             analyze_fn_body(g, fn_entry);
         }
     }
+
+    if (g->errors.length != 0) {
+        return;
+    }
+
+    // second pass over functions for detecting async
+    for (g->fn_defs_index = 0; g->fn_defs_index < g->fn_defs.length; g->fn_defs_index += 1) {
+        ZigFn *fn_entry = g->fn_defs.at(g->fn_defs_index);
+        analyze_fn_async(g, fn_entry);
+    }
 }
 
 ZigType *get_int_type(CodeGen *g, bool is_signed, uint32_t size_in_bits) {
@@ -7173,11 +7241,7 @@ void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn) {
     if (fn->raw_di_type != nullptr) return;
 
     ZigType *fn_type = fn->type_entry;
-    FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
-    bool cc_async = fn_type_id->cc == CallingConventionAsync;
-    bool inferred_async = fn->resume_blocks.length != 0;
-    bool is_async = cc_async || inferred_async;
-    if (!is_async) {
+    if (!fn_is_async(fn)) {
         resolve_llvm_types_fn_type(g, fn_type);
         fn->raw_type_ref = fn_type->data.fn.raw_type_ref;
         fn->raw_di_type = fn_type->data.fn.raw_di_type;
@@ -7223,8 +7287,6 @@ static void resolve_llvm_types_anyerror(CodeGen *g) {
 }
 
 static void resolve_llvm_types_coro_frame(CodeGen *g, ZigType *frame_type, ResolveStatus wanted_resolve_status) {
-    if (frame_type->llvm_di_type != nullptr) return;
-
     resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status);
     frame_type->llvm_type = frame_type->data.frame.locals_struct->llvm_type;
     frame_type->llvm_di_type = frame_type->data.frame.locals_struct->llvm_di_type;
src/analyze.hpp
@@ -248,5 +248,6 @@ bool is_container(ZigType *type_entry);
 ConstExprValue *analyze_const_value(CodeGen *g, Scope *scope, AstNode *node, ZigType *type_entry, Buf *type_name);
 
 void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn);
+bool fn_is_async(ZigFn *fn);
 
 #endif
src/codegen.cpp
@@ -371,7 +371,7 @@ static LLVMValueRef fn_llvm_value(CodeGen *g, ZigFn *fn_table_entry) {
         symbol_name = buf_sprintf("\x01_%s", buf_ptr(symbol_name));
     }
 
-    bool is_async = fn_table_entry->resume_blocks.length != 0 || cc == CallingConventionAsync;
+    bool is_async = fn_is_async(fn_table_entry);
 
 
     ZigType *fn_type = fn_table_entry->type_entry;
@@ -1847,7 +1847,7 @@ static bool iter_function_params_c_abi(CodeGen *g, ZigType *fn_type, FnWalk *fn_
                 }
                 case FnWalkIdInits: {
                     clear_debug_source_node(g);
-                    if (fn_walk->data.inits.fn->resume_blocks.length == 0) {
+                    if (!fn_is_async(fn_walk->data.inits.fn)) {
                         LLVMValueRef arg = LLVMGetParam(llvm_fn, fn_walk->data.inits.gen_i);
                         LLVMTypeRef ptr_to_int_type_ref = LLVMPointerType(LLVMIntType((unsigned)ty_size * 8), 0);
                         LLVMValueRef bitcasted = LLVMBuildBitCast(g->builder, var->value_ref, ptr_to_int_type_ref, "");
@@ -1945,7 +1945,7 @@ void walk_function_params(CodeGen *g, ZigType *fn_type, FnWalk *fn_walk) {
                 assert(variable);
                 assert(variable->value_ref);
 
-                if (!handle_is_ptr(variable->var_type) && fn_walk->data.inits.fn->resume_blocks.length == 0) {
+                if (!handle_is_ptr(variable->var_type) && !fn_is_async(fn_walk->data.inits.fn)) {
                     clear_debug_source_node(g);
                     ZigType *fn_type = fn_table_entry->type_entry;
                     unsigned gen_arg_index = fn_type->data.fn.gen_param_info[variable->src_arg_index].gen_index;
@@ -1986,7 +1986,7 @@ static LLVMValueRef ir_render_save_err_ret_addr(CodeGen *g, IrExecutable *execut
 }
 
 static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *return_instruction) {
-    if (g->cur_fn->resume_blocks.length != 0) {
+    if (fn_is_async(g->cur_fn)) {
         if (ir_want_runtime_safety(g, &return_instruction->base)) {
             LLVMValueRef locals_ptr = g->cur_ret_ptr;
             LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, locals_ptr, coro_resume_index_index, "");
@@ -3387,8 +3387,10 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
     LLVMValueRef result;
 
     if (instruction->is_async) {
+        size_t ret_1_or_0 = type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 1 : 0;
         for (size_t arg_i = 0; arg_i < gen_param_values.length; arg_i += 1) {
-            LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, result_loc, coro_arg_start + arg_i, "");
+            LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, result_loc,
+                    coro_arg_start + ret_1_or_0 + arg_i, "");
             LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr);
         }
         ZigLLVMBuildCall(g->builder, fn_val, &result_loc, 1, llvm_cc, fn_inline, "");
@@ -5983,7 +5985,7 @@ static void build_all_basic_blocks(CodeGen *g, ZigFn *fn) {
     assert(executable->basic_block_list.length > 0);
     LLVMValueRef fn_val = fn_llvm_value(g, fn);
     LLVMBasicBlockRef first_bb = nullptr;
-    if (fn->resume_blocks.length != 0) {
+    if (fn_is_async(fn)) {
         first_bb = LLVMAppendBasicBlock(fn_val, "AsyncSwitch");
         fn->preamble_llvm_block = first_bb;
     }
@@ -6171,7 +6173,7 @@ static void do_code_gen(CodeGen *g) {
         build_all_basic_blocks(g, fn_table_entry);
         clear_debug_source_node(g);
 
-        bool is_async = cc == CallingConventionAsync || fn_table_entry->resume_blocks.length != 0;
+        bool is_async = fn_is_async(fn_table_entry);
 
         if (want_sret || is_async) {
             g->cur_ret_ptr = LLVMGetParam(fn, 0);
@@ -6261,7 +6263,9 @@ static void do_code_gen(CodeGen *g) {
                 fn_walk_var.data.vars.var = var;
                 iter_function_params_c_abi(g, fn_table_entry->type_entry, &fn_walk_var, var->src_arg_index);
             } else if (is_async) {
-                var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, "");
+                size_t ret_1_or_0 = type_has_bits(fn_type_id->return_type) ? 1 : 0;
+                var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr,
+                        coro_arg_start + ret_1_or_0 + var_i, "");
                 if (var->decl_node) {
                     var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope),
                         buf_ptr(&var->name), import->data.structure.root_struct->di_file,
src/ir.cpp
@@ -15383,6 +15383,13 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c
             zig_panic("TODO async call");
         }
 
+        if (!call_instruction->is_async) {
+            if (impl_fn_type_id->cc == CallingConventionAsync && parent_fn_entry->inferred_async_node == nullptr) {
+                parent_fn_entry->inferred_async_node = fn_ref->source_node;
+            }
+            parent_fn_entry->call_list.append({call_instruction->base.source_node, impl_fn});
+        }
+
         IrInstruction *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base,
                 impl_fn, nullptr, impl_param_count, casted_args, fn_inline,
                 call_instruction->is_async, casted_new_stack, result_loc,
@@ -15458,6 +15465,15 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c
         return ira->codegen->invalid_instruction;
     }
 
+    if (!call_instruction->is_async) {
+        if (fn_type_id->cc == CallingConventionAsync && parent_fn_entry->inferred_async_node == nullptr) {
+            parent_fn_entry->inferred_async_node = fn_ref->source_node;
+        }
+        if (fn_entry != nullptr) {
+            parent_fn_entry->call_list.append({call_instruction->base.source_node, fn_entry});
+        }
+    }
+
     if (call_instruction->is_async) {
         IrInstruction *result = ir_analyze_async_call(ira, call_instruction, fn_entry, fn_type, fn_ref,
                 casted_args, call_param_count);
@@ -24142,6 +24158,9 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru
     new_bb->resume_index = fn_entry->resume_blocks.length + 2;
 
     fn_entry->resume_blocks.append(new_bb);
+    if (fn_entry->inferred_async_node == nullptr) {
+        fn_entry->inferred_async_node = instruction->base.source_node;
+    }
 
     ir_push_resume_block(ira, old_dest_block);
 
build.zig
@@ -375,7 +375,9 @@ fn addLibUserlandStep(b: *Builder) void {
     artifact.bundle_compiler_rt = true;
     artifact.setTarget(builtin.arch, builtin.os, builtin.abi);
     artifact.linkSystemLibrary("c");
-    artifact.linkSystemLibrary("ntdll");
+    if (builtin.os == .windows) {
+        artifact.linkSystemLibrary("ntdll");
+    }
     const libuserland_step = b.step("libuserland", "Build the userland compiler library for use in stage1");
     libuserland_step.dependOn(&artifact.step);