Commit 9ca8d9e21a
Changed files (4)
test
stage1
behavior
src/analyze.cpp
@@ -4232,31 +4232,40 @@ static Error analyze_callee_async(CodeGen *g, ZigFn *fn, ZigFn *callee, AstNode
{
if (modifier == CallModifierNoAsync)
return ErrorNone;
- if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
- return ErrorNone;
- if (callee->anal_state == FnAnalStateReady) {
- analyze_fn_body(g, callee);
- if (callee->anal_state == FnAnalStateInvalid) {
- return ErrorSemanticAnalyzeFail;
- }
+ bool callee_is_async = false;
+ switch (callee->type_entry->data.fn.fn_type_id.cc) {
+ case CallingConventionUnspecified:
+ break;
+ case CallingConventionAsync:
+ callee_is_async = true;
+ break;
+ default:
+ return ErrorNone;
}
- bool callee_is_async;
- if (callee->anal_state == FnAnalStateComplete) {
- analyze_fn_async(g, callee, true);
- if (callee->anal_state == FnAnalStateInvalid) {
- return ErrorSemanticAnalyzeFail;
+ if (!callee_is_async) {
+ if (callee->anal_state == FnAnalStateReady) {
+ analyze_fn_body(g, callee);
+ if (callee->anal_state == FnAnalStateInvalid) {
+ return ErrorSemanticAnalyzeFail;
+ }
}
- callee_is_async = fn_is_async(callee);
- } else {
- // If it's already been determined, use that value. Otherwise
- // assume non-async, emit an error later if it turned out to be async.
- if (callee->inferred_async_node == nullptr ||
- callee->inferred_async_node == inferred_async_checking)
- {
- callee->assumed_non_async = call_node;
- callee_is_async = false;
+ if (callee->anal_state == FnAnalStateComplete) {
+ analyze_fn_async(g, callee, true);
+ if (callee->anal_state == FnAnalStateInvalid) {
+ return ErrorSemanticAnalyzeFail;
+ }
+ callee_is_async = fn_is_async(callee);
} else {
- callee_is_async = callee->inferred_async_node != inferred_async_none;
+ // If it's already been determined, use that value. Otherwise
+ // assume non-async, emit an error later if it turned out to be async.
+ if (callee->inferred_async_node == nullptr ||
+ callee->inferred_async_node == inferred_async_checking)
+ {
+ callee->assumed_non_async = call_node;
+ callee_is_async = false;
+ } else {
+ callee_is_async = callee->inferred_async_node != inferred_async_none;
+ }
}
}
if (callee_is_async) {
@@ -4333,6 +4342,8 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
}
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
+ // TODO If this is a noasync await, it doesn't count
+ // https://github.com/ziglang/zig/issues/3157
switch (analyze_callee_async(g, fn, await->target_fn, await->base.source_node, must_not_be_async,
CallModifierNone))
{
@@ -5771,15 +5782,39 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
if (!fn_is_async(callee))
continue;
- IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
- alloca_gen->base.id = IrInstructionIdAllocaGen;
- alloca_gen->base.source_node = call->base.source_node;
- alloca_gen->base.scope = call->base.scope;
- alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
- alloca_gen->base.ref_count = 1;
- alloca_gen->name_hint = "";
- fn->alloca_gen_list.append(alloca_gen);
- call->frame_result_loc = &alloca_gen->base;
+ call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node, fn,
+ callee_frame_type, "");
+ }
+ // Since this frame is async, an await might represent a suspend point, and
+ // therefore need to spill.
+ for (size_t i = 0; i < fn->await_list.length; i += 1) {
+ IrInstructionAwaitGen *await = fn->await_list.at(i);
+ // TODO If this is a noasync await, it doesn't need to spill
+ // https://github.com/ziglang/zig/issues/3157
+ if (await->result_loc != nullptr) {
+ // If there's a result location, that is the spill
+ continue;
+ }
+ if (!type_has_bits(await->base.value.type))
+ continue;
+ if (await->base.value.special != ConstValSpecialRuntime)
+ continue;
+ if (await->base.ref_count == 0)
+ continue;
+ if (await->target_fn != nullptr) {
+ // we might not need to suspend
+ analyze_fn_async(g, await->target_fn, false);
+ if (await->target_fn->anal_state == FnAnalStateInvalid) {
+ frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid;
+ return ErrorSemanticAnalyzeFail;
+ }
+ if (!fn_is_async(await->target_fn)) {
+ // This await does not represent a suspend point. No spill needed.
+ continue;
+ }
+ }
+ await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
+ await->base.value.type, "");
}
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
@@ -8505,3 +8540,18 @@ void src_assert(bool ok, AstNode *source_node) {
const char *msg = "assertion failed. This is a bug in the Zig compiler.";
stage2_panic(msg, strlen(msg));
}
+
+IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
+ ZigType *var_type, const char *name_hint)
+{
+ IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
+ alloca_gen->base.id = IrInstructionIdAllocaGen;
+ alloca_gen->base.source_node = source_node;
+ alloca_gen->base.scope = scope;
+ alloca_gen->base.value.type = get_pointer_to_type(g, var_type, false);
+ alloca_gen->base.ref_count = 1;
+ alloca_gen->name_hint = name_hint;
+ fn->alloca_gen_list.append(alloca_gen);
+ return &alloca_gen->base;
+}
+
src/analyze.hpp
@@ -258,4 +258,8 @@ ZigType *resolve_struct_field_type(CodeGen *g, TypeStructField *struct_field);
void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
+IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
+ ZigType *var_type, const char *name_hint);
+
+
#endif
src/codegen.cpp
@@ -1661,6 +1661,14 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
if (!type_has_bits(instruction->value.type))
return nullptr;
if (!instruction->llvm_value) {
+ if (instruction->id == IrInstructionIdAwaitGen) {
+ IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
+ if (await->result_loc != nullptr) {
+ instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
+ await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
+ return instruction->llvm_value;
+ }
+ }
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
assert(instruction->value.type);
render_const_val(g, &instruction->value, "");
@@ -5645,7 +5653,6 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
// At this point resuming the function will continue from resume_bb.
// This code is as if it is running inside the suspend block.
-
// supply the awaiter return pointer
if (type_has_bits(result_type)) {
LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start + 1, "");
@@ -5703,9 +5710,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
LLVMBuildBr(g->builder, end_bb);
LLVMPositionBuilderAtEnd(g->builder, end_bb);
- if (type_has_bits(result_type) && result_loc != nullptr) {
- return get_handle_value(g, result_loc, result_type, ptr_result_type);
- }
+ // Rely on the spill for the llvm_value to be populated.
+ // See the implementation of ir_llvm_value.
return nullptr;
}
@@ -7153,15 +7159,8 @@ static void do_code_gen(CodeGen *g) {
if (call->frame_result_loc != nullptr)
continue;
ZigType *callee_frame_type = get_fn_frame_type(g, call->fn_entry);
- IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
- alloca_gen->base.id = IrInstructionIdAllocaGen;
- alloca_gen->base.source_node = call->base.source_node;
- alloca_gen->base.scope = call->base.scope;
- alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
- alloca_gen->base.ref_count = 1;
- alloca_gen->name_hint = "";
- fn_table_entry->alloca_gen_list.append(alloca_gen);
- call->frame_result_loc = &alloca_gen->base;
+ call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node,
+ fn_table_entry, callee_frame_type, "");
}
// allocate temporary stack data
for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {
test/stage1/behavior/async_fn.zig
@@ -1108,3 +1108,19 @@ test "noasync function call" {
};
S.doTheTest();
}
+
+test "await used in expression and awaiting fn with no suspend but async calling convention" {
+ const S = struct {
+ fn atest() void {
+ var f1 = async add(1, 2);
+ var f2 = async add(3, 4);
+
+ const sum = (await f1) + (await f2);
+ expect(sum == 10);
+ }
+ async fn add(a: i32, b: i32) i32 {
+ return a + b;
+ }
+ };
+ _ = async S.atest();
+}