Commit 9ca8d9e21a

Andrew Kelley <andrew@ziglang.org>
2019-09-06 22:17:39
fix await used in an expression generating bad LLVM
1 parent 9423d38
Changed files (4)
src/analyze.cpp
@@ -4232,31 +4232,40 @@ static Error analyze_callee_async(CodeGen *g, ZigFn *fn, ZigFn *callee, AstNode
 {
     if (modifier == CallModifierNoAsync)
         return ErrorNone;
-    if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
-        return ErrorNone;
-    if (callee->anal_state == FnAnalStateReady) {
-        analyze_fn_body(g, callee);
-        if (callee->anal_state == FnAnalStateInvalid) {
-            return ErrorSemanticAnalyzeFail;
-        }
+    bool callee_is_async = false;
+    switch (callee->type_entry->data.fn.fn_type_id.cc) {
+        case CallingConventionUnspecified:
+            break;
+        case CallingConventionAsync:
+            callee_is_async = true;
+            break;
+        default:
+            return ErrorNone;
     }
-    bool callee_is_async;
-    if (callee->anal_state == FnAnalStateComplete) {
-        analyze_fn_async(g, callee, true);
-        if (callee->anal_state == FnAnalStateInvalid) {
-            return ErrorSemanticAnalyzeFail;
+    if (!callee_is_async) {
+        if (callee->anal_state == FnAnalStateReady) {
+            analyze_fn_body(g, callee);
+            if (callee->anal_state == FnAnalStateInvalid) {
+                return ErrorSemanticAnalyzeFail;
+            }
         }
-        callee_is_async = fn_is_async(callee);
-    } else {
-        // If it's already been determined, use that value. Otherwise
-        // assume non-async, emit an error later if it turned out to be async.
-        if (callee->inferred_async_node == nullptr ||
-            callee->inferred_async_node == inferred_async_checking)
-        {
-            callee->assumed_non_async = call_node;
-            callee_is_async = false;
+        if (callee->anal_state == FnAnalStateComplete) {
+            analyze_fn_async(g, callee, true);
+            if (callee->anal_state == FnAnalStateInvalid) {
+                return ErrorSemanticAnalyzeFail;
+            }
+            callee_is_async = fn_is_async(callee);
         } else {
-            callee_is_async = callee->inferred_async_node != inferred_async_none;
+            // If it's already been determined, use that value. Otherwise
+            // assume non-async, emit an error later if it turned out to be async.
+            if (callee->inferred_async_node == nullptr ||
+                callee->inferred_async_node == inferred_async_checking)
+            {
+                callee->assumed_non_async = call_node;
+                callee_is_async = false;
+            } else {
+                callee_is_async = callee->inferred_async_node != inferred_async_none;
+            }
         }
     }
     if (callee_is_async) {
@@ -4333,6 +4342,8 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
     }
     for (size_t i = 0; i < fn->await_list.length; i += 1) {
         IrInstructionAwaitGen *await = fn->await_list.at(i);
+        // TODO If this is a noasync await, it doesn't count
+        // https://github.com/ziglang/zig/issues/3157
         switch (analyze_callee_async(g, fn, await->target_fn, await->base.source_node, must_not_be_async,
                     CallModifierNone))
         {
@@ -5771,15 +5782,39 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
         if (!fn_is_async(callee))
             continue;
 
-        IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
-        alloca_gen->base.id = IrInstructionIdAllocaGen;
-        alloca_gen->base.source_node = call->base.source_node;
-        alloca_gen->base.scope = call->base.scope;
-        alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
-        alloca_gen->base.ref_count = 1;
-        alloca_gen->name_hint = "";
-        fn->alloca_gen_list.append(alloca_gen);
-        call->frame_result_loc = &alloca_gen->base;
+        call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node, fn,
+                callee_frame_type, "");
+    }
+    // Since this frame is async, an await might represent a suspend point, and
+    // therefore need to spill.
+    for (size_t i = 0; i < fn->await_list.length; i += 1) {
+        IrInstructionAwaitGen *await = fn->await_list.at(i);
+        // TODO If this is a noasync await, it doesn't need to spill
+        // https://github.com/ziglang/zig/issues/3157
+        if (await->result_loc != nullptr) {
+            // If there's a result location, that is the spill
+            continue;
+        }
+        if (!type_has_bits(await->base.value.type))
+            continue;
+        if (await->base.value.special != ConstValSpecialRuntime)
+            continue;
+        if (await->base.ref_count == 0)
+            continue;
+        if (await->target_fn != nullptr) {
+            // we might not need to suspend
+            analyze_fn_async(g, await->target_fn, false);
+            if (await->target_fn->anal_state == FnAnalStateInvalid) {
+                frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid;
+                return ErrorSemanticAnalyzeFail;
+            }
+            if (!fn_is_async(await->target_fn)) {
+                // This await does not represent a suspend point. No spill needed.
+                continue;
+            }
+        }
+        await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
+                await->base.value.type, "");
     }
     FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
     ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
@@ -8505,3 +8540,18 @@ void src_assert(bool ok, AstNode *source_node) {
     const char *msg = "assertion failed. This is a bug in the Zig compiler.";
     stage2_panic(msg, strlen(msg));
 }
+
+IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
+        ZigType *var_type, const char *name_hint)
+{ // Builds an AllocaGen instruction for a slot of `var_type`, registers it on `fn`, and returns its base instruction.
+    IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
+    alloca_gen->base.id = IrInstructionIdAllocaGen;
+    alloca_gen->base.source_node = source_node; // source_node and scope are taken from the caller's instruction
+    alloca_gen->base.scope = scope;
+    alloca_gen->base.value.type = get_pointer_to_type(g, var_type, false); // the instruction's type is a pointer to var_type; NOTE(review): `false` presumably means non-const -- confirm
+    alloca_gen->base.ref_count = 1; // created with a single reference already counted
+    alloca_gen->name_hint = name_hint; // callers shown in this change pass ""
+    fn->alloca_gen_list.append(alloca_gen); // do_code_gen iterates this list to "allocate temporary stack data"
+    return &alloca_gen->base;
+}
+
src/analyze.hpp
@@ -258,4 +258,8 @@ ZigType *resolve_struct_field_type(CodeGen *g, TypeStructField *struct_field);
 
 void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
 
+IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
+        ZigType *var_type, const char *name_hint);
+
+
 #endif
src/codegen.cpp
@@ -1661,6 +1661,14 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
     if (!type_has_bits(instruction->value.type))
         return nullptr;
     if (!instruction->llvm_value) {
+        if (instruction->id == IrInstructionIdAwaitGen) {
+            IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
+            if (await->result_loc != nullptr) {
+                instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
+                    await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
+                return instruction->llvm_value;
+            }
+        }
         src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
         assert(instruction->value.type);
         render_const_val(g, &instruction->value, "");
@@ -5645,7 +5653,6 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     // At this point resuming the function will continue from resume_bb.
     // This code is as if it is running inside the suspend block.
 
-
     // supply the awaiter return pointer
     if (type_has_bits(result_type)) {
         LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start + 1, "");
@@ -5703,9 +5710,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
     LLVMBuildBr(g->builder, end_bb);
 
     LLVMPositionBuilderAtEnd(g->builder, end_bb);
-    if (type_has_bits(result_type) && result_loc != nullptr) {
-        return get_handle_value(g, result_loc, result_type, ptr_result_type);
-    }
+    // Rely on the spill for the llvm_value to be populated.
+    // See the implementation of ir_llvm_value.
     return nullptr;
 }
 
@@ -7153,15 +7159,8 @@ static void do_code_gen(CodeGen *g) {
                 if (call->frame_result_loc != nullptr)
                     continue;
                 ZigType *callee_frame_type = get_fn_frame_type(g, call->fn_entry);
-                IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
-                alloca_gen->base.id = IrInstructionIdAllocaGen;
-                alloca_gen->base.source_node = call->base.source_node;
-                alloca_gen->base.scope = call->base.scope;
-                alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
-                alloca_gen->base.ref_count = 1;
-                alloca_gen->name_hint = "";
-                fn_table_entry->alloca_gen_list.append(alloca_gen);
-                call->frame_result_loc = &alloca_gen->base;
+                call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node,
+                        fn_table_entry, callee_frame_type, "");
             }
             // allocate temporary stack data
             for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {
test/stage1/behavior/async_fn.zig
@@ -1108,3 +1108,19 @@ test "noasync function call" {
     };
     S.doTheTest();
 }
+
+test "await used in expression and awaiting fn with no suspend but async calling convention" {
+    const S = struct {
+        fn atest() void {
+            var f1 = async add(1, 2);
+            var f2 = async add(3, 4);
+
+            const sum = (await f1) + (await f2); // both awaits are operands of a single expression
+            expect(sum == 10); // (1 + 2) + (3 + 4)
+        }
+        async fn add(a: i32, b: i32) i32 { // explicitly async, yet the body contains no suspend point
+            return a + b;
+        }
+    };
+    _ = async S.atest(); // atest is started with `async`; its frame is discarded
+}