Commit 1dd0c3d49f

Andrew Kelley <andrew@ziglang.org>
2019-08-01 22:41:30
fix calling an inferred async function
1 parent e7ae4e4
Changed files (5)
src/all_types.hpp
@@ -2605,7 +2605,6 @@ struct IrInstructionCallGen {
     IrInstruction **args;
     IrInstruction *result_loc;
     IrInstruction *frame_result_loc;
-    IrBasicBlock *resume_block;
 
     IrInstruction *new_stack;
     FnInline fn_inline;
src/analyze.cpp
@@ -5185,13 +5185,6 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
         if (!fn_is_async(callee))
             continue;
 
-        IrBasicBlock *new_resume_block = allocate<IrBasicBlock>(1);
-        new_resume_block->name_hint = "CallResume";
-        new_resume_block->split_llvm_fn = reinterpret_cast<LLVMValueRef>(0x1);
-        fn->resume_blocks.append(new_resume_block);
-        call->resume_block = new_resume_block;
-        fn->analyzed_executable.basic_block_list.append(new_resume_block);
-
         ZigType *callee_frame_type = get_coro_frame_type(g, callee);
 
         IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
src/codegen.cpp
@@ -3327,6 +3327,92 @@ static void set_call_instr_sret(CodeGen *g, LLVMValueRef call_instr) {
     LLVMAddCallSiteAttribute(call_instr, 1, sret_attr);
 }
 
+static void render_async_spills(CodeGen *g) {
+    ZigType *fn_type = g->cur_fn->type_entry;
+    ZigType *import = get_scope_import(&g->cur_fn->fndef_scope->base);
+    size_t async_var_index = coro_arg_start + (type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 2 : 0);
+    for (size_t var_i = 0; var_i < g->cur_fn->variable_list.length; var_i += 1) {
+        ZigVar *var = g->cur_fn->variable_list.at(var_i);
+
+        if (!type_has_bits(var->var_type)) {
+            continue;
+        }
+        if (ir_get_var_is_comptime(var))
+            continue;
+        switch (type_requires_comptime(g, var->var_type)) {
+            case ReqCompTimeInvalid:
+                zig_unreachable();
+            case ReqCompTimeYes:
+                continue;
+            case ReqCompTimeNo:
+                break;
+        }
+        if (var->src_arg_index == SIZE_MAX) {
+            continue;
+        }
+
+        var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
+                buf_ptr(&var->name));
+        async_var_index += 1;
+        if (var->decl_node) {
+            var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope),
+                buf_ptr(&var->name), import->data.structure.root_struct->di_file,
+                (unsigned)(var->decl_node->line + 1),
+                get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0);
+            gen_var_debug_decl(g, var);
+        }
+    }
+    for (size_t alloca_i = 0; alloca_i < g->cur_fn->alloca_gen_list.length; alloca_i += 1) {
+        IrInstructionAllocaGen *instruction = g->cur_fn->alloca_gen_list.at(alloca_i);
+        ZigType *ptr_type = instruction->base.value.type;
+        assert(ptr_type->id == ZigTypeIdPointer);
+        ZigType *child_type = ptr_type->data.pointer.child_type;
+        if (!type_has_bits(child_type))
+            continue;
+        if (instruction->base.ref_count == 0)
+            continue;
+        if (instruction->base.value.special != ConstValSpecialRuntime) {
+            if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
+                    ConstValSpecialRuntime)
+            {
+                continue;
+            }
+        }
+        instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
+                instruction->name_hint);
+        async_var_index += 1;
+    }
+}
+
+static void render_async_var_decls(CodeGen *g, Scope *scope) {
+    render_async_spills(g);
+    for (;;) {
+        switch (scope->id) {
+            case ScopeIdCImport:
+                zig_unreachable();
+            case ScopeIdFnDef:
+                return;
+            case ScopeIdVarDecl: {
+                ZigVar *var = reinterpret_cast<ScopeVarDecl *>(scope)->var;
+                if (var->ptr_instruction != nullptr) {
+                    render_decl_var(g, var);
+                }
+                // fallthrough
+            }
+            case ScopeIdDecls:
+            case ScopeIdBlock:
+            case ScopeIdDefer:
+            case ScopeIdDeferExpr:
+            case ScopeIdLoop:
+            case ScopeIdSuspend:
+            case ScopeIdCompTime:
+            case ScopeIdRuntime:
+                scope = scope->parent;
+                continue;
+        }
+    }
+}
+
 static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCallGen *instruction) {
     LLVMValueRef fn_val;
     ZigType *fn_type;
@@ -3431,15 +3517,19 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
         ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, "");
         return nullptr;
     } else if (callee_is_async) {
+        LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn);
         LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, "");
-        LLVMValueRef new_fn_ptr = instruction->resume_block->split_llvm_fn;
-        LLVMBuildStore(g->builder, new_fn_ptr, fn_ptr_ptr);
+        LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr);
 
         LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, "");
         ZigLLVMSetTailCall(call_inst);
         LLVMBuildRetVoid(g->builder);
 
-        LLVMPositionBuilderAtEnd(g->builder, instruction->resume_block->llvm_block);
+        g->cur_fn_val = split_llvm_fn;
+        g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0);
+        LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume");
+        LLVMPositionBuilderAtEnd(g->builder, call_bb);
+        render_async_var_decls(g, instruction->base.scope);
         return nullptr;
     }
 
@@ -5193,92 +5283,6 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
     zig_unreachable();
 }
 
-static void render_async_spills(CodeGen *g) {
-    ZigType *fn_type = g->cur_fn->type_entry;
-    ZigType *import = get_scope_import(&g->cur_fn->fndef_scope->base);
-    size_t async_var_index = coro_arg_start + (type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 2 : 0);
-    for (size_t var_i = 0; var_i < g->cur_fn->variable_list.length; var_i += 1) {
-        ZigVar *var = g->cur_fn->variable_list.at(var_i);
-
-        if (!type_has_bits(var->var_type)) {
-            continue;
-        }
-        if (ir_get_var_is_comptime(var))
-            continue;
-        switch (type_requires_comptime(g, var->var_type)) {
-            case ReqCompTimeInvalid:
-                zig_unreachable();
-            case ReqCompTimeYes:
-                continue;
-            case ReqCompTimeNo:
-                break;
-        }
-        if (var->src_arg_index == SIZE_MAX) {
-            continue;
-        }
-
-        var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
-                buf_ptr(&var->name));
-        async_var_index += 1;
-        if (var->decl_node) {
-            var->di_loc_var = ZigLLVMCreateAutoVariable(g->dbuilder, get_di_scope(g, var->parent_scope),
-                buf_ptr(&var->name), import->data.structure.root_struct->di_file,
-                (unsigned)(var->decl_node->line + 1),
-                get_llvm_di_type(g, var->var_type), !g->strip_debug_symbols, 0);
-            gen_var_debug_decl(g, var);
-        }
-    }
-    for (size_t alloca_i = 0; alloca_i < g->cur_fn->alloca_gen_list.length; alloca_i += 1) {
-        IrInstructionAllocaGen *instruction = g->cur_fn->alloca_gen_list.at(alloca_i);
-        ZigType *ptr_type = instruction->base.value.type;
-        assert(ptr_type->id == ZigTypeIdPointer);
-        ZigType *child_type = ptr_type->data.pointer.child_type;
-        if (!type_has_bits(child_type))
-            continue;
-        if (instruction->base.ref_count == 0)
-            continue;
-        if (instruction->base.value.special != ConstValSpecialRuntime) {
-            if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special !=
-                    ConstValSpecialRuntime)
-            {
-                continue;
-            }
-        }
-        instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index,
-                instruction->name_hint);
-        async_var_index += 1;
-    }
-}
-
-static void render_async_var_decls(CodeGen *g, Scope *scope) {
-    render_async_spills(g);
-    for (;;) {
-        switch (scope->id) {
-            case ScopeIdCImport:
-                zig_unreachable();
-            case ScopeIdFnDef:
-                return;
-            case ScopeIdVarDecl: {
-                ZigVar *var = reinterpret_cast<ScopeVarDecl *>(scope)->var;
-                if (var->ptr_instruction != nullptr) {
-                    render_decl_var(g, var);
-                }
-                // fallthrough
-            }
-            case ScopeIdDecls:
-            case ScopeIdBlock:
-            case ScopeIdDefer:
-            case ScopeIdDeferExpr:
-            case ScopeIdLoop:
-            case ScopeIdSuspend:
-            case ScopeIdCompTime:
-            case ScopeIdRuntime:
-                scope = scope->parent;
-                continue;
-        }
-    }
-}
-
 static void ir_render(CodeGen *g, ZigFn *fn_entry) {
     assert(fn_entry);
 
test/stage1/behavior/coroutines.zig
@@ -82,55 +82,55 @@ test "local variable in async function" {
     S.doTheTest();
 }
 
-//test "calling an inferred async function" {
-//    const S = struct {
-//        var x: i32 = 1;
-//        var other_frame: *@Frame(other) = undefined;
-//
-//        fn doTheTest() void {
-//            const p = async first();
-//            expect(x == 1);
-//            resume other_frame.*;
-//            expect(x == 2);
-//        }
-//
-//        fn first() void {
-//            other();
-//        }
-//        fn other() void {
-//            other_frame = @frame();
-//            suspend;
-//            x += 1;
-//        }
-//    };
-//    S.doTheTest();
-//}
-//
-//test "@frameSize" {
-//    const S = struct {
-//        fn doTheTest() void {
-//            {
-//                var ptr = @ptrCast(async fn(i32) void, other);
-//                const size = @frameSize(ptr);
-//                expect(size == @sizeOf(@Frame(other)));
-//            }
-//            {
-//                var ptr = @ptrCast(async fn() void, first);
-//                const size = @frameSize(ptr);
-//                expect(size == @sizeOf(@Frame(first)));
-//            }
-//        }
-//
-//        fn first() void {
-//            other(1);
-//        }
-//        fn other(param: i32) void {
-//            var local: i32 = undefined;
-//            suspend;
-//        }
-//    };
-//    S.doTheTest();
-//}
+test "calling an inferred async function" {
+    const S = struct {
+        var x: i32 = 1;
+        var other_frame: *@Frame(other) = undefined;
+
+        fn doTheTest() void {
+            const p = async first();
+            expect(x == 1);
+            resume other_frame.*;
+            expect(x == 2);
+        }
+
+        fn first() void {
+            other();
+        }
+        fn other() void {
+            other_frame = @frame();
+            suspend;
+            x += 1;
+        }
+    };
+    S.doTheTest();
+}
+
+test "@frameSize" {
+    const S = struct {
+        fn doTheTest() void {
+            {
+                var ptr = @ptrCast(async fn(i32) void, other);
+                const size = @frameSize(ptr);
+                expect(size == @sizeOf(@Frame(other)));
+            }
+            {
+                var ptr = @ptrCast(async fn() void, first);
+                const size = @frameSize(ptr);
+                expect(size == @sizeOf(@Frame(first)));
+            }
+        }
+
+        fn first() void {
+            other(1);
+        }
+        fn other(param: i32) void {
+            var local: i32 = undefined;
+            suspend;
+        }
+    };
+    S.doTheTest();
+}
 
 //test "coroutine suspend, resume" {
 //    seq('a');
BRANCH_TODO
@@ -1,5 +1,3 @@
- * fix @frameSize
- * fix calling an inferred async function
  * await
  * await of a non async function
  * await in single-threaded mode