Commit d1a98ccff4

Andrew Kelley <andrew@ziglang.org>
2019-09-07 06:12:15
implement spills when expressions used across suspend points
closes #3077
1 parent 9ca8d9e
Changed files (7)
src/all_types.hpp
@@ -2124,6 +2124,7 @@ enum ScopeId {
     ScopeIdCompTime,
     ScopeIdRuntime,
     ScopeIdTypeOf,
+    ScopeIdExpr,
 };
 
 struct Scope {
@@ -2271,6 +2272,24 @@ struct ScopeTypeOf {
     Scope base;
 };
 
+// Tri-state value used to cache the answer to a yes/no question that is
+// computed on demand. MemoizedBoolUnknown is the first enumerator, so it has
+// the value 0.
+// NOTE(review): presumably a zero-initialized struct therefore starts out as
+// "unknown" - confirm that the allocator used for the owning struct zeroes memory.
+enum MemoizedBool {
+    MemoizedBoolUnknown, // answer has not been computed yet
+    MemoizedBoolFalse,   // answer computed: false
+    MemoizedBoolTrue,    // answer computed: true
+};
+
+// This scope is created for each expression.
+// It's used to identify when an instruction needs to be spilled,
+// so that it can be accessed after a suspend point.
+struct ScopeExpr {
+    Scope base;
+
+    // Expression scopes nested directly inside this one, in the order in which
+    // create_expr_scope appended them. children_len is the number of entries
+    // in children_ptr.
+    ScopeExpr **children_ptr;
+    size_t children_len;
+
+    // Cached answer to "do instructions in this scope need to be spilled?".
+    // mark_suspension_point sets it to MemoizedBoolTrue; scope_needs_spill
+    // resolves MemoizedBoolUnknown from the parent scope and stores the result.
+    MemoizedBool need_spill;
+};
+
 // synchronized with code in define_builtin_compile_vars
 enum AtomicOrder {
     AtomicOrderUnordered,
@@ -2510,6 +2529,10 @@ struct IrInstruction {
     // with this child field.
     IrInstruction *child;
     IrBasicBlock *owner_bb;
+    // Almost any instruction may need to be stored in a local variable before suspending
+    // and then loaded again after resuming, when it is part of an expression that contains
+    // a suspend point, such as: x + await y
+    IrInstruction *spill;
     IrInstructionId id;
     // true if this instruction was generated by zig and not from user code
     bool is_gen;
src/analyze.cpp
@@ -96,6 +96,30 @@ static ScopeDecls **get_container_scope_ptr(ZigType *type_entry) {
     zig_unreachable();
 }
 
+// Returns the nearest ScopeExpr reachable from `scope` (possibly `scope`
+// itself), looking through loop and runtime scopes only. Returns nullptr as
+// soon as any other kind of scope is encountered.
+static ScopeExpr *find_expr_scope(Scope *scope) {
+    Scope *cursor = scope;
+    while (true) {
+        switch (cursor->id) {
+            // Transparent scopes: step to the parent and inspect it next.
+            case ScopeIdRuntime:
+            case ScopeIdLoop:
+                cursor = cursor->parent;
+                break;
+
+            // Found what we were looking for.
+            case ScopeIdExpr:
+                return reinterpret_cast<ScopeExpr *>(cursor);
+
+            // Every other scope kind terminates the search.
+            case ScopeIdBlock:
+            case ScopeIdTypeOf:
+            case ScopeIdSuspend:
+            case ScopeIdCImport:
+            case ScopeIdVarDecl:
+            case ScopeIdCompTime:
+            case ScopeIdFnDef:
+            case ScopeIdDecls:
+            case ScopeIdDeferExpr:
+            case ScopeIdDefer:
+                return nullptr;
+        }
+    }
+}
+
 ScopeDecls *get_container_scope(ZigType *type_entry) {
     return *get_container_scope_ptr(type_entry);
 }
@@ -203,6 +227,20 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
     return &scope->base;
 }
 
+// Allocates a new ScopeExpr for `node` under `parent` and returns its base
+// Scope. When `parent` has an enclosing ScopeExpr (as determined by
+// find_expr_scope), the new scope is appended to that scope's child list.
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
+    ScopeExpr *expr_scope = allocate<ScopeExpr>(1);
+    init_scope(g, &expr_scope->base, ScopeIdExpr, node, parent);
+
+    ScopeExpr *enclosing = find_expr_scope(parent);
+    if (enclosing == nullptr)
+        return &expr_scope->base;
+
+    // Grow the enclosing scope's child array by one slot and fill it.
+    size_t old_len = enclosing->children_len;
+    enclosing->children_ptr = reallocate_nonzero<ScopeExpr *>(enclosing->children_ptr, old_len, old_len + 1);
+    enclosing->children_ptr[old_len] = expr_scope;
+    enclosing->children_len = old_len + 1;
+
+    return &expr_scope->base;
+}
+
 ZigType *get_scope_import(Scope *scope) {
     while (scope) {
         if (scope->id == ScopeIdDecls) {
@@ -5654,6 +5692,69 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) {
     return fn_type;
 }
 
+// Records that `scope` contains a suspend point.
+//
+// Walks up the scope chain starting at the parent of `scope`. Loop and runtime
+// scopes are stepped over; any other non-expression scope ends the walk. For
+// every enclosing ScopeExpr that is reached:
+//  * the children listed before the child we arrived from are marked as
+//    needing a spill;
+//  * the child we arrived from, and the children listed after it, are left
+//    untouched;
+//  * the enclosing ScopeExpr itself is marked as needing a spill.
+// If `scope` is not itself a ScopeExpr, no siblings are marked for the first
+// enclosing ScopeExpr; only that ScopeExpr is marked.
+//
+// This is what distinguishes
+//      x + (await y)
+// where x precedes the suspend point, from
+//      (await y) + x
+// where x follows it.
+static void mark_suspension_point(Scope *scope) {
+    ScopeExpr *child_expr_scope = (scope->id == ScopeIdExpr) ? reinterpret_cast<ScopeExpr *>(scope) : nullptr;
+    for (;;) {
+        scope = scope->parent;
+        switch (scope->id) {
+            case ScopeIdDefer:
+            case ScopeIdDeferExpr:
+            case ScopeIdDecls:
+            case ScopeIdFnDef:
+            case ScopeIdCompTime:
+            case ScopeIdVarDecl:
+            case ScopeIdCImport:
+            case ScopeIdSuspend:
+            case ScopeIdTypeOf:
+            case ScopeIdBlock:
+                return;
+            case ScopeIdLoop:
+            case ScopeIdRuntime:
+                continue;
+            case ScopeIdExpr: {
+                ScopeExpr *parent_expr_scope = reinterpret_cast<ScopeExpr *>(scope);
+                if (child_expr_scope != nullptr) {
+                    // Mark every sibling that precedes the child we came from.
+                    // The bounds assertion runs before each element is read, so
+                    // a child that is missing from the list trips the assertion
+                    // instead of reading past the end of children_ptr.
+                    for (size_t i = 0;; i += 1) {
+                        assert(i < parent_expr_scope->children_len);
+                        ScopeExpr *sibling = parent_expr_scope->children_ptr[i];
+                        if (sibling == child_expr_scope)
+                            break;
+                        sibling->need_spill = MemoizedBoolTrue;
+                    }
+                }
+                parent_expr_scope->need_spill = MemoizedBoolTrue;
+                // Continue upward, now treating this scope as the child.
+                child_expr_scope = parent_expr_scope;
+                continue;
+            }
+        }
+    }
+}
+
+// Returns whether instructions belonging to `scope` have to be spilled.
+// The answer comes from the nearest ScopeExpr found by find_expr_scope; when
+// there is none the answer is false. An answer that is still unknown is
+// inherited from the parent scope and cached in the ScopeExpr.
+static bool scope_needs_spill(Scope *scope) {
+    ScopeExpr *expr_scope = find_expr_scope(scope);
+    if (expr_scope == nullptr)
+        return false;
+
+    switch (expr_scope->need_spill) {
+        case MemoizedBoolTrue:
+            return true;
+        case MemoizedBoolFalse:
+            return false;
+        case MemoizedBoolUnknown: {
+            // Not decided yet: take the parent's answer and remember it.
+            bool inherited = scope_needs_spill(expr_scope->base.parent);
+            expr_scope->need_spill = inherited ? MemoizedBoolTrue : MemoizedBoolFalse;
+            return inherited;
+        }
+    }
+    zig_unreachable();
+}
+
 static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
     Error err;
 
@@ -5786,21 +5887,17 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
                 callee_frame_type, "");
     }
     // Since this frame is async, an await might represent a suspend point, and
-    // therefore need to spill.
+    // therefore need to spill. It also needs to mark expr scopes as having to spill.
+    // For example: foo() + await z
+    // The function call result of foo() must be spilled.
     for (size_t i = 0; i < fn->await_list.length; i += 1) {
         IrInstructionAwaitGen *await = fn->await_list.at(i);
-        // TODO If this is a noasync await, it doesn't need to spill
+        // TODO If this is a noasync await, it doesn't suspend
         // https://github.com/ziglang/zig/issues/3157
-        if (await->result_loc != nullptr) {
-            // If there's a result location, that is the spill
+        if (await->base.value.special != ConstValSpecialRuntime) {
+            // Known at comptime. No spill, no suspend.
             continue;
         }
-        if (!type_has_bits(await->base.value.type))
-            continue;
-        if (await->base.value.special != ConstValSpecialRuntime)
-            continue;
-        if (await->base.ref_count == 0)
-            continue;
         if (await->target_fn != nullptr) {
             // we might not need to suspend
             analyze_fn_async(g, await->target_fn, false);
@@ -5809,13 +5906,53 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
                 return ErrorSemanticAnalyzeFail;
             }
             if (!fn_is_async(await->target_fn)) {
-                // This await does not represent a suspend point. No spill needed.
+                // This await does not represent a suspend point. No spill needed,
+                // and no need to mark ExprScope.
                 continue;
             }
         }
+        // This await is a suspend point, but it might not need a spill.
+        // We do need to mark the ExprScope as having a suspend point in it.
+        mark_suspension_point(await->base.scope);
+
+        if (await->result_loc != nullptr) {
+            // If there's a result location, that is the spill
+            continue;
+        }
+        if (await->base.ref_count == 0)
+            continue;
+        if (!type_has_bits(await->base.value.type))
+            continue;
         await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
                 await->base.value.type, "");
     }
+    // Now that we've marked all the expr scopes that have to spill, we go over the instructions
+    // and spill the relevant ones.
+    for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
+        IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
+        for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
+            IrInstruction *instruction = block->instruction_list.at(instr_i);
+            if (instruction->id == IrInstructionIdAwaitGen ||
+                instruction->id == IrInstructionIdVarPtr ||
+                instruction->id == IrInstructionIdDeclRef ||
+                instruction->id == IrInstructionIdAllocaGen)
+            {
+                // This instruction does its own spilling specially, or otherwise doesn't need it.
+                continue;
+            }
+            if (instruction->value.special != ConstValSpecialRuntime)
+                continue;
+            if (instruction->ref_count == 0)
+                continue;
+            if (!type_has_bits(instruction->value.type))
+                continue;
+            if (scope_needs_spill(instruction->scope)) {
+                instruction->spill = ir_create_alloca(g, instruction->scope, instruction->source_node,
+                        fn, instruction->value.type, "");
+            }
+        }
+    }
+
     FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
     ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
 
src/analyze.hpp
@@ -114,6 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
 Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
 Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
 Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
+Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
 
 void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
 ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
@@ -261,5 +262,4 @@ void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
 IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
         ZigType *var_type, const char *name_hint);
 
-
 #endif
src/codegen.cpp
@@ -649,6 +649,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
         case ScopeIdCompTime:
         case ScopeIdRuntime:
         case ScopeIdTypeOf:
+        case ScopeIdExpr:
             return get_di_scope(g, scope->parent);
     }
     zig_unreachable();
@@ -1644,7 +1645,6 @@ static void gen_assign_raw(CodeGen *g, LLVMValueRef ptr, ZigType *ptr_type,
     LLVMValueRef ored_value = LLVMBuildOr(g->builder, shifted_value, anded_containing_int, "");
 
     gen_store(g, ored_value, ptr, ptr_type);
-    return;
 }
 
 static void gen_var_debug_decl(CodeGen *g, ZigVar *var) {
@@ -1664,11 +1664,16 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
         if (instruction->id == IrInstructionIdAwaitGen) {
             IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
             if (await->result_loc != nullptr) {
-                instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
+                return get_handle_value(g, ir_llvm_value(g, await->result_loc),
                     await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
-                return instruction->llvm_value;
             }
         }
+        if (instruction->spill != nullptr) {
+            ZigType *ptr_type = instruction->spill->value.type;
+            src_assert(ptr_type->id == ZigTypeIdPointer, instruction->source_node);
+            return get_handle_value(g, ir_llvm_value(g, instruction->spill),
+                ptr_type->data.pointer.child_type, instruction->spill->value.type);
+        }
         src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
         assert(instruction->value.type);
         render_const_val(g, &instruction->value, "");
@@ -3786,6 +3791,7 @@ static void render_async_var_decls(CodeGen *g, Scope *scope) {
             case ScopeIdCompTime:
             case ScopeIdRuntime:
             case ScopeIdTypeOf:
+            case ScopeIdExpr:
                 scope = scope->parent;
                 continue;
         }
@@ -6049,6 +6055,11 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
                 set_debug_location(g, instruction);
             }
             instruction->llvm_value = ir_render_instruction(g, executable, instruction);
+            if (instruction->spill != nullptr) {
+                LLVMValueRef spill_ptr = ir_llvm_value(g, instruction->spill);
+                gen_assign_raw(g, spill_ptr, instruction->spill->value.type, instruction->llvm_value);
+                instruction->llvm_value = nullptr;
+            }
         }
         current_block->llvm_exit_block = LLVMGetInsertBlock(g->builder);
     }
src/ir.cpp
@@ -3364,6 +3364,7 @@ static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_sco
             case ScopeIdCompTime:
             case ScopeIdRuntime:
             case ScopeIdTypeOf:
+            case ScopeIdExpr:
                 scope = scope->parent;
                 continue;
             case ScopeIdDeferExpr:
@@ -3420,6 +3421,7 @@ static bool ir_gen_defers_for_block(IrBuilder *irb, Scope *inner_scope, Scope *o
             case ScopeIdCompTime:
             case ScopeIdRuntime:
             case ScopeIdTypeOf:
+            case ScopeIdExpr:
                 scope = scope->parent;
                 continue;
             case ScopeIdDeferExpr:
@@ -8158,7 +8160,15 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
         result_loc = no_result_loc();
         ir_build_reset_result(irb, scope, node, result_loc);
     }
-    IrInstruction *result = ir_gen_node_raw(irb, node, scope, lval, result_loc);
+    Scope *child_scope;
+    if (irb->exec->is_inline ||
+        (irb->exec->fn_entry != nullptr && irb->exec->fn_entry->child_scope == scope))
+    {
+        child_scope = scope;
+    } else {
+        child_scope = create_expr_scope(irb->codegen, node, scope);
+    }
+    IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
     if (result == irb->codegen->invalid_instruction) {
         if (irb->exec->first_err_trace_msg == nullptr) {
             irb->exec->first_err_trace_msg = irb->codegen->trace_err;
std/event/future.zig
@@ -104,11 +104,7 @@ fn testFuture(loop: *Loop) void {
     var b = async waitOnFuture(&future);
     resolveFuture(&future);
 
-    // TODO https://github.com/ziglang/zig/issues/3077
-    //const result = (await a) + (await b);
-    const a_result = await a;
-    const b_result = await b;
-    const result = a_result + b_result;
+    const result = (await a) + (await b);
 
     testing.expect(result == 12);
 }
test/stage1/behavior/async_fn.zig
@@ -921,12 +921,10 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
             var sum: u32 = 0;
 
             f1_awaited = true;
-            const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
-            sum += try result_f1;
+            sum += try await f1;
 
             f2_awaited = true;
-            const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
-            sum += try result_f2;
+            sum += try await f2;
 
             return sum;
         }
@@ -943,8 +941,7 @@ fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
 
         fn amain(result: *u32) void {
             var x = async fib(std.heap.direct_allocator, 10);
-            const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
-            result.* = res catch unreachable;
+            result.* = (await x) catch unreachable;
         }
     };
 }
@@ -1002,8 +999,7 @@ test "@asyncCall using the result location inside the frame" {
             return 1234;
         }
         fn getAnswer(f: anyframe->i32, out: *i32) void {
-            var res = await f; // TODO https://github.com/ziglang/zig/issues/3077
-            out.* = res;
+            out.* = await f;
         }
     };
     var data: i32 = 1;
@@ -1124,3 +1120,19 @@ test "await used in expression and awaiting fn with no suspend but async calling
     };
     _ = async S.atest();
 }
+
+// Regression test for https://github.com/ziglang/zig/issues/3077:
+// a function call result that is used in the same expression as a later await.
+test "await used in expression after a fn call" {
+    const S = struct {
+        fn atest() void {
+            var f1 = async add(3, 4);
+            var sum: i32 = 0;
+            // foo() is evaluated before the await, so its result must still be
+            // available once the await has completed.
+            sum = foo() + await f1;
+            expect(sum == 8); // 1 + (3 + 4)
+        }
+        async fn add(a: i32, b: i32) i32 {
+            return a + b;
+        }
+        fn foo() i32 { return 1; }
+    };
+    _ = async S.atest();
+}