Commit 229323e13a

Andrew Kelley <andrew@ziglang.org>
2019-09-07 23:37:17
fix suspensions inside for loops generating invalid LLVM IR
closes #3076
1 parent d3cf040
Changed files (5)
src/all_types.hpp
@@ -25,6 +25,7 @@ struct ZigFn;
 struct Scope;
 struct ScopeBlock;
 struct ScopeFnDef;
+struct ScopeExpr;
 struct ZigType;
 struct ZigVar;
 struct ErrorTableEntry;
@@ -2230,6 +2231,7 @@ struct ScopeLoop {
     ZigList<IrInstruction *> *incoming_values;
     ZigList<IrBasicBlock *> *incoming_blocks;
     ResultLocPeerParent *peer_parent;
+    ScopeExpr *spill_scope;
 };
 
 // This scope blocks certain things from working such as comptime continue
src/analyze.cpp
@@ -227,7 +227,7 @@ Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent) {
     return &scope->base;
 }
 
-Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
+ScopeExpr *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
     ScopeExpr *scope = allocate<ScopeExpr>(1);
     init_scope(g, &scope->base, ScopeIdExpr, node, parent);
     ScopeExpr *parent_expr = find_expr_scope(parent);
@@ -238,7 +238,7 @@ Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent) {
         parent_expr->children_ptr[parent_expr->children_len] = scope;
         parent_expr->children_len = new_len;
     }
-    return &scope->base;
+    return scope;
 }
 
 ZigType *get_scope_import(Scope *scope) {
@@ -5713,7 +5713,6 @@ static void mark_suspension_point(Scope *scope) {
             case ScopeIdCImport:
             case ScopeIdSuspend:
             case ScopeIdTypeOf:
-            case ScopeIdBlock:
                 return;
             case ScopeIdLoop:
             case ScopeIdRuntime:
@@ -5730,6 +5729,14 @@ static void mark_suspension_point(Scope *scope) {
                 child_expr_scope = parent_expr_scope;
                 continue;
             }
+            case ScopeIdBlock:
+                if (scope->parent->parent->id == ScopeIdLoop) {
+                    ScopeLoop *loop_scope = reinterpret_cast<ScopeLoop *>(scope->parent->parent);
+                    if (loop_scope->spill_scope != nullptr) {
+                        loop_scope->spill_scope->need_spill = MemoizedBoolTrue;
+                    }
+                }
+                return;
         }
     }
 }
@@ -5928,6 +5935,15 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
         await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
                 await->base.value.type, "");
     }
+    for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
+        IrBasicBlock *block = fn->analyzed_executable.basic_block_list.at(block_i);
+        for (size_t instr_i = 0; instr_i < block->instruction_list.length; instr_i += 1) {
+            IrInstruction *instruction = block->instruction_list.at(instr_i);
+            if (instruction->id == IrInstructionIdSuspendFinish) {
+                mark_suspension_point(instruction->scope);
+            }
+        }
+    }
     // Now that we've marked all the expr scopes that have to spill, we go over the instructions
     // and spill the relevant ones.
     for (size_t block_i = 0; block_i < fn->analyzed_executable.basic_block_list.length; block_i += 1) {
@@ -6395,9 +6411,7 @@ void eval_min_max_value(CodeGen *g, ZigType *type_entry, ConstExprValue *const_v
 }
 
 static void render_const_val_ptr(CodeGen *g, Buf *buf, ConstExprValue *const_val, ZigType *type_entry) {
-    assert(type_entry->id == ZigTypeIdPointer);
-
-    if (type_entry->data.pointer.child_type->id == ZigTypeIdOpaque) {
+    if (type_entry->id == ZigTypeIdPointer && type_entry->data.pointer.child_type->id == ZigTypeIdOpaque) {
         buf_append_buf(buf, &type_entry->name);
         return;
     }
src/analyze.hpp
@@ -114,7 +114,7 @@ ScopeFnDef *create_fndef_scope(CodeGen *g, AstNode *node, Scope *parent, ZigFn *
 Scope *create_comptime_scope(CodeGen *g, AstNode *node, Scope *parent);
 Scope *create_runtime_scope(CodeGen *g, AstNode *node, Scope *parent, IrInstruction *is_comptime);
 Scope *create_typeof_scope(CodeGen *g, AstNode *node, Scope *parent);
-Scope *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
+ScopeExpr *create_expr_scope(CodeGen *g, AstNode *node, Scope *parent);
 
 void init_const_str_lit(CodeGen *g, ConstExprValue *const_val, Buf *str);
 ConstExprValue *create_const_str_lit(CodeGen *g, Buf *str);
src/ir.cpp
@@ -6474,6 +6474,8 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     IrInstruction *is_comptime = ir_build_const_bool(irb, parent_scope, node,
         ir_should_inline(irb->exec, parent_scope) || node->data.for_expr.is_inline);
 
+    ScopeExpr *spill_scope = create_expr_scope(irb->codegen, node, parent_scope);
+
     AstNode *index_var_source_node;
     ZigVar *index_var;
     const char *index_var_name;
@@ -6504,11 +6506,11 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
 
     Buf *len_field_name = buf_create_from_str("len");
     IrInstruction *len_ref = ir_build_field_ptr(irb, parent_scope, node, array_val_ptr, len_field_name, false);
-    IrInstruction *len_val = ir_build_load_ptr(irb, parent_scope, node, len_ref);
+    IrInstruction *len_val = ir_build_load_ptr(irb, &spill_scope->base, node, len_ref);
     ir_build_br(irb, parent_scope, node, cond_block, is_comptime);
 
     ir_set_cursor_at_end_and_append_block(irb, cond_block);
-    IrInstruction *index_val = ir_build_load_ptr(irb, parent_scope, node, index_ptr);
+    IrInstruction *index_val = ir_build_load_ptr(irb, &spill_scope->base, node, index_ptr);
     IrInstruction *cond = ir_build_bin_op(irb, parent_scope, node, IrBinOpCmpLessThan, index_val, len_val, false);
     IrBasicBlock *after_cond_block = irb->current_basic_block;
     IrInstruction *void_else_value = else_node ? nullptr : ir_mark_gen(ir_build_const_void(irb, parent_scope, node));
@@ -6518,7 +6520,8 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     ResultLocPeerParent *peer_parent = ir_build_result_peers(irb, cond_br_inst, end_block, result_loc, is_comptime);
 
     ir_set_cursor_at_end_and_append_block(irb, body_block);
-    IrInstruction *elem_ptr = ir_build_elem_ptr(irb, parent_scope, node, array_val_ptr, index_val, false,
+    Scope *elem_ptr_scope = node->data.for_expr.elem_is_ptr ? parent_scope : &spill_scope->base;
+    IrInstruction *elem_ptr = ir_build_elem_ptr(irb, elem_ptr_scope, node, array_val_ptr, index_val, false,
             PtrLenSingle, nullptr);
     // TODO make it an error to write to element variable or i variable.
     Buf *elem_var_name = elem_node->data.symbol_expr.symbol;
@@ -6526,7 +6529,7 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     Scope *child_scope = elem_var->child_scope;
 
     IrInstruction *var_ptr = node->data.for_expr.elem_is_ptr ?
-        ir_build_ref(irb, parent_scope, elem_node, elem_ptr, true, false) : elem_ptr;
+        ir_build_ref(irb, &spill_scope->base, elem_node, elem_ptr, true, false) : elem_ptr;
     ir_build_var_decl_src(irb, parent_scope, elem_node, elem_var, nullptr, var_ptr);
 
     ZigList<IrInstruction *> incoming_values = {0};
@@ -6539,6 +6542,7 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     loop_scope->incoming_values = &incoming_values;
     loop_scope->lval = LValNone;
     loop_scope->peer_parent = peer_parent;
+    loop_scope->spill_scope = spill_scope;
 
     // Note the body block of the loop is not the place that lval and result_loc are used -
     // it's actually in break statements, handled similarly to return statements.
@@ -8166,7 +8170,7 @@ static IrInstruction *ir_gen_node_extra(IrBuilder *irb, AstNode *node, Scope *sc
     {
         child_scope = scope;
     } else {
-        child_scope = create_expr_scope(irb->codegen, node, scope);
+        child_scope = &create_expr_scope(irb->codegen, node, scope)->base;
     }
     IrInstruction *result = ir_gen_node_raw(irb, node, child_scope, lval, result_loc);
     if (result == irb->codegen->invalid_instruction) {
test/stage1/behavior/async_fn.zig
@@ -1151,3 +1151,30 @@ test "async fn call used in expression after a fn call" {
     };
     _ = async S.atest();
 }
+
+test "suspend in for loop" {
+    const S = struct {
+        var global_frame: ?anyframe = null;
+
+        fn doTheTest() void {
+            _ = async atest();
+            while (global_frame) |f| resume f;
+        }
+
+        fn atest() void {
+            expect(func([_]u8{ 1, 2, 3 }) == 6);
+        }
+        fn func(stuff: []const u8) u32 {
+            global_frame = @frame();
+            var sum: u32 = 0;
+            for (stuff) |x| {
+                suspend;
+                sum += x;
+            }
+            global_frame = null;
+            return sum;
+        }
+    };
+    S.doTheTest();
+}
+