Commit d1a98ccff4
Changed files (7)
std
event
test
stage1
behavior
src/all_types.hpp
@@ -2124,6 +2124,7 @@ enum ScopeId {
ScopeIdCompTime,
ScopeIdRuntime,
ScopeIdTypeOf,
+ ScopeIdExpr,
};
struct Scope {
@@ -2271,6 +2272,24 @@ struct ScopeTypeOf {
Scope base;
};
+enum MemoizedBool {
+ MemoizedBoolUnknown,
+ MemoizedBoolFalse,
+ MemoizedBoolTrue,
+};
+
+// This scope is created for each expression.
+// It's used to identify when an instruction needs to be spilled,
+// so that it can be accessed after a suspend point.
+struct ScopeExpr {
+ Scope base;
+
+ ScopeExpr **children_ptr;
+ size_t children_len;
+
+ MemoizedBool need_spill;
+};
+
// synchronized with code in define_builtin_compile_vars
enum AtomicOrder {
AtomicOrderUnordered,
@@ -2510,6 +2529,10 @@ struct IrInstruction {
// with this child field.
IrInstruction *child;
IrBasicBlock *owner_bb;
+ // Nearly any instruction can have to be stored as a local variable before suspending
+ // and then loaded after resuming, in case there is an expression with a suspend point
+ // in it, such as: x + await y
+ IrInstruction *spill;
IrInstructionId id;
// true if this instruction was generated by zig and not from user code
bool is_gen;
src/analyze.cpp
@@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) {
zig_unreachable();
}
+static ScopeExpr *find_expr_scope(Scope *scope) {
+ for (;;) {
+ switch (scope->id) {
+ case ScopeIdExpr:
+ return reinterpret_cast<ScopeExpr *>(scope);
+ case ScopeIdDefer:
+ case ScopeIdDeferExpr:
+ case ScopeIdDecls:
+ case ScopeIdFnDef:
+ case ScopeIdCompTime:
+ case ScopeIdVarDecl:
+ case ScopeIdCImport:
+ case ScopeIdSuspend:
+ case ScopeIdTypeOf:
+ case ScopeIdBlock:
+ return nullptr;
+ case ScopeIdLoop:
+ case ScopeIdRuntime:
+ scope = scope->parent;
+ continue;
+ }
+ }
+}
+
ScopeDecls *get_container_scope(ZigType *type_entry) {
return *get_container_scope_ptr(type_entry);
}
@@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
return &scope->base;
}
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
+ ScopeExpr *scope = allocate<ScopeExpr>(1);
+ init_scope(g, &scope->base, ScopeIdExpr, node, parent);
+ ScopeExpr *parent_expr = find_expr_scope(parent);
+ if (parent_expr != nullptr) {
+ size_t new_len = parent_expr->children_len + 1;
+ parent_expr->children_ptr = reallocate_nonzero<ScopeExpr *>(
+ parent_expr->children_ptr, parent_expr->children_len, new_len);
+ parent_expr->children_ptr[parent_expr->children_len] = scope;
+ parent_expr->children_len = new_len;
+ }
+ return &scope->base;
+}
+
ZigType *get_scope_import(Scope *scope) {
while (scope) {
if (scope->id == ScopeIdDecls) {
@@ -5654,6 +5692,69 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) {
return fn_type;
}
+// Traverse up to the very top ExprScope, which has children.
+// We have just arrived at the top from a child. That child,
+// and its next siblings, do not need to be marked. But the previous
+// siblings do.
+// x + (await y)
+// vs
+// (await y) + x
+static void mark_suspension_point(Scope *scope) {
+ ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr;
+ for (;;) {
+ scope = scope->parent;
+ switch (scope->id) {
+ case ScopeIdDefer:
+ case ScopeIdDeferExpr:
+ case ScopeIdDecls:
+ case ScopeIdFnDef:
+ case ScopeIdCompTime:
+ case ScopeIdVarDecl:
+ case ScopeIdCImport:
+ case ScopeIdSuspend:
+ case ScopeIdTypeOf:
+ case ScopeIdBlock:
+ return;
+ case ScopeIdLoop:
+ case ScopeIdRuntime:
+ continue;
+ case ScopeIdExpr: {
+ ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope);
+ if (child_expr_scope != nullptr) {
+ for (size_t i = 0; parent_expr_scope->children_ptr[i] != child_expr_scope; i += 1) {
+ assert(i < parent_expr_scope->children_len);
+ parent_expr_scope->children_ptr[i]->need_spill = MemoizedBoolTrue;
+ }
+ }
+ parent_expr_scope->need_spill = MemoizedBoolTrue;
+ child_expr_scope = parent_expr_scope;
+ continue;
+ }
+ }
+ }
+}
+
+static bool scope_needs_spill(Scope *scope) {
+ ScopeExpr *scope_expr = find_expr_scope(scope);
+ if (scope_expr == nullptr) return false;
+
+ switch (scope_expr->need_spill) {
+ case MemoizedBoolUnknown:
+ if (scope_needs_spill(scope_expr->base.parent)) {
+ scope_expr->need_spill = MemoizedBoolTrue;
+ return true;
+ } else {
+ scope_expr->need_spill = MemoizedBoolFalse;
+ return false;
+ }
+ case MemoizedBoolFalse:
+ return false;
+ case MemoizedBoolTrue:
+ return true;
+ }
+ zig_unreachable();
+}
+
static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
Error err;
@@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
callee_frame_type, "");
}
// Since this frame is async, an await might represent a suspend point, and
- // therefore need to spill.
+ // therefore need to spill. It also needs to mark expr scopes as having to spill.
+ // For example: foo() + await z
+ // The funtion call result of foo() must be spilled.
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
- // TODO If this is a noasync await, it doesn't need to spill
+ // TODO If this is a noasync await, it doesn't suspend
// https://github.com/ziglang/zig/issues/3157
- if (await->result_loc != nullptr) {
- // If there's a result location, that is the spill
+ if (await->base.value.special != ConstValSpecialRuntime) {
+ // Known at comptime. No spill, no suspend.
continue;
}
- if (!type_has_bits(await->base.value.type))
- continue;
- if (await->base.value.special != ConstValSpecialRuntime)
- continue;
- if (await->base.ref_count == 0)
- continue;
if (await->target_fn != nullptr) {
// we might not need to suspend
analyze_fn_async(g, await->target_fn, false);
@@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
return ErrorSemanticAnalyzeFail;
}
if (!fn_is_async(await->target_fn)) {
- // This await does not represent a suspend point. No spill needed.
+ // This await does not represent a suspend point. No spill needed,
+ // and no need to mark ExprScope.
continue;
}
}
+ // This await is a suspend point, but it might not need a spill.
+ // We do need to mark the ExprScope as having a suspend point in it.
+ mark_suspension_point(await->base.scope);
+
+ if (await->result_loc != nullptr) {
+ // If there's a result location, that is the spill
+ continue;
+ }
+ if (await->base.ref_count == 0)
+ continue;
+ if (!type_has_bits(await->base.value.type))
+ continue;
await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
await->base.value.type, "");
}
+ // Now that we've marked all the expr scopes that have to spill, we go over the instructions
+ // and spill the relevant ones.
+ for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
+ IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
+ for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
+ IrInstruction *instruction = block->instruction_list.at(instr_i);
+ if (instruction->id == IrInstructionIdAwaitGen ||
+ instruction->id == IrInstructionIdVarPtr ||
+ instruction->id == IrInstructionIdDeclRef ||
+ instruction->id == IrInstructionIdAllocaGen)
+ {
+ // This instruction does its own spilling specially, or otherwise doesn't need it.
+ continue;
+ }
+ if (instruction->value.special != ConstValSpecialRuntime)
+ continue;
+ if (instruction->ref_count == 0)
+ continue;
+ if (!type_has_bits(instruction->value.type))
+ continue;
+ if (scope_needs_spill(instruction->scope)) {
+ instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node,
+ fn, instruction->value.type, "");
+ }
+ }
+ }
+
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
src/analyze.hpp
@@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
@@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
ZigType *var_type, const char *name_hint);
-
#endif
src/codegen.cpp
@@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
return get_di_scope(g, scope->parent);
}
zig_unreachable();
@@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type,
LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, "");
gen_store(g, ored_value, ptr, ptr_type);
- return;
}
static void gen_var_debug_decl(CodeGen *g, ZigVar *var) {
@@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
if (instruction->id == IrInstructionIdAwaitGen) {
IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
if (await->result_loc != nullptr) {
- instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
+ return get_handle_value(g, ir_llvm_value(g, await->result_loc),
await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
- return instruction->llvm_value;
}
}
+ if (instruction->spill != nullptr) {
+ ZigType *ptr_type = instruction->spill->value.type;
+ src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node);
+ return get_handle_value(g, ir_llvm_value(g, instruction->spill),
+ ptr_type->data.pointer.child_type, instruction->spill->value.type);
+ }
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
assert(instruction->value.type);
render_const_val(g, &instruction->value, "");
@@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) {
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
}
@@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
set_debug_location(g, instruction);
}
instruction->llvm_value = ir_render_instruction(g, executable, instruction);
+ if (instruction->spill != nullptr) {
+ LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill);
+ gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value);
+ instruction->llvm_value = nullptr;
+ }
}
current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder);
}
src/ir.cpp
@@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o
case ScopeIdCompTime:
case ScopeIdRuntime:
case ScopeIdTypeOf:
+ case ScopeIdExpr:
scope = scope->parent;
continue;
case ScopeIdDeferExpr:
@@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
result_loc = no_result_loc();
ir_build_reset_result(irb, scope, node, result_loc);
}
- IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc);
+ Scope *child_scope;
+ if (irb->exec->is_inline ||
+ (irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope))
+ {
+ child_scope = scope;
+ } else {
+ child_scope = create_expr_scope(irb->codegen, node, scope);
+ }
+ IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
if (result == irb->codegen->invalid_instruction) {
if (irb->exec->first_err_trace_msg == nullptr) {
irb->exec->first_err_trace_msg = irb->codegen->trace_err;
std/event/future.zig
@@ -104,11 +104,7 @@ fn testFuture(loop: *Loop) void {
var b = async waitOnFuture(&future);
resolveFuture(&future);
- // TODO https://github.com/ziglang/zig/issues/3077
- //const result = (await a) + (await b);
- const a_result = await a;
- const b_result = await b;
- const result = a_result + b_result;
+ const result = (await a) + (await b);
testing.expect(result == 12);
}
test/stage1/behavior/async_fn.zig
@@ -921,12 +921,10 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
var sum: u32 = 0;
f1_awaited = true;
- const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
- sum += try result_f1;
+ sum += try await f1;
f2_awaited = true;
- const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
- sum += try result_f2;
+ sum += try await f2;
return sum;
}
@@ -943,8 +941,7 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
fn amain(result: *u32) void {
var x = async fib(std.heap.direct_allocator, 10);
- const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
- result.* = res catch unreachable;
+ result.* = (await x) catch unreachable;
}
};
}
@@ -1002,8 +999,7 @@ test "@asyncCall using the result location inside the frame" {
return 1234;
}
fn getAnswer(f: anyframe->i32, out: *i32) void {
- var res = await f; // TODO https://github.com/ziglang/zig/issues/3077
- out.* = res;
+ out.* = await f;
}
};
var data: i32 = 1;
@@ -1124,3 +1120,19 @@ test "await used in expression and awaiting fn with no suspend but async calling
};
_ = async S.atest();
}
+
+test "await used in expression after a fn call" {
+ const S = struct {
+ fn atest() void {
+ var f1 = async add(3, 4);
+ var sum: i32 = 0;
+ sum = foo() + await f1;
+ expect(sum == 8);
+ }
+ async fn add(a: i32, b: i32) i32 {
+ return a + b;
+ }
+ fn foo() i32 { return 1; }
+ };
+ _ = async S.atest();
+}