Commit 06909ceaab

Andrew Kelley <superjoe30@gmail.com>
2018-04-19 04:21:54
support break in suspend blocks
* you can label suspend blocks * labeled break supports suspend blocks See #803
1 parent ca4341f
doc/langref.html.in
@@ -5918,7 +5918,7 @@ Defer(body) = ("defer" | "deferror") body
 
 IfExpression(body) = "if" "(" Expression ")" body option("else" BlockExpression(body))
 
-SuspendExpression(body) = "suspend" option(("|" Symbol "|" body))
+SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
 
 IfErrorExpression(body) = "if" "(" Expression ")" option("|" option("*") Symbol "|") body "else" "|" Symbol "|" BlockExpression(body)
 
src/all_types.hpp
@@ -867,6 +867,7 @@ struct AstNodeAwaitExpr {
 };
 
 struct AstNodeSuspend {
+    Buf *name;
     AstNode *block;
     AstNode *promise_symbol;
 };
@@ -1757,6 +1758,7 @@ enum ScopeId {
     ScopeIdVarDecl,
     ScopeIdCImport,
     ScopeIdLoop,
+    ScopeIdSuspend,
     ScopeIdFnDef,
     ScopeIdCompTime,
     ScopeIdCoroPrelude,
@@ -1852,6 +1854,17 @@ struct ScopeLoop {
     ZigList<IrBasicBlock *> *incoming_blocks;
 };
 
+// This scope is created for a suspend block in order to have labeled
+// suspend for breaking out of a suspend and for detecting if a suspend
+// block is inside a suspend block.
+struct ScopeSuspend {
+    Scope base;
+
+    Buf *name;
+    IrBasicBlock *resume_block;
+    bool reported_err;
+};
+
 // This scope is created for a comptime expression.
 // NodeTypeCompTime, NodeTypeSwitchExpr
 struct ScopeCompTime {
src/analyze.cpp
@@ -156,6 +156,14 @@ ScopeLoop *create_loop_scope(AstNode *node, Scope *parent) {
     return scope;
 }
 
+ScopeSuspend *create_suspend_scope(AstNode *node, Scope *parent) {
+    assert(node->type == NodeTypeSuspend);
+    ScopeSuspend *scope = allocate<ScopeSuspend>(1);
+    init_scope(&scope->base, ScopeIdSuspend, node, parent);
+    scope->name = node->data.suspend.name;
+    return scope;
+}
+
 ScopeFnDef *create_fndef_scope(AstNode *node, Scope *parent, FnTableEntry *fn_entry) {
     ScopeFnDef *scope = allocate<ScopeFnDef>(1);
     init_scope(&scope->base, ScopeIdFnDef, node, parent);
@@ -3616,6 +3624,7 @@ FnTableEntry *scope_get_fn_if_root(Scope *scope) {
             case ScopeIdVarDecl:
             case ScopeIdCImport:
             case ScopeIdLoop:
+            case ScopeIdSuspend:
             case ScopeIdCompTime:
             case ScopeIdCoroPrelude:
                 scope = scope->parent;
src/analyze.hpp
@@ -104,6 +104,7 @@ ScopeDeferExpr *create_defer_expr_scope(AstNode *node, Scope *parent);
 Scope *create_var_scope(AstNode *node, Scope *parent, VariableTableEntry *var);
 ScopeCImport *create_cimport_scope(AstNode *node, Scope *parent);
 ScopeLoop *create_loop_scope(AstNode *node, Scope *parent);
+ScopeSuspend *create_suspend_scope(AstNode *node, Scope *parent);
 ScopeFnDef *create_fndef_scope(AstNode *node, Scope *parent, FnTableEntry *fn_entry);
 ScopeDecls *create_decls_scope(AstNode *node, Scope *parent, TypeTableEntry *container_type, ImportTableEntry *import);
 Scope *create_comptime_scope(AstNode *node, Scope *parent);
src/codegen.cpp
@@ -654,6 +654,7 @@ static ZigLLVMDIScope *get_di_scope(CodeGen *g, Scope *scope) {
         }
         case ScopeIdDeferExpr:
         case ScopeIdLoop:
+        case ScopeIdSuspend:
         case ScopeIdCompTime:
         case ScopeIdCoroPrelude:
             return get_di_scope(g, scope->parent);
src/ir.cpp
@@ -2829,6 +2829,18 @@ static void ir_set_cursor_at_end_and_append_block(IrBuilder *irb, IrBasicBlock *
     ir_set_cursor_at_end(irb, basic_block);
 }
 
+static ScopeSuspend *get_scope_suspend(Scope *scope) {
+    while (scope) {
+        if (scope->id == ScopeIdSuspend)
+            return (ScopeSuspend *)scope;
+        if (scope->id == ScopeIdFnDef)
+            return nullptr;
+
+        scope = scope->parent;
+    }
+    return nullptr;
+}
+
 static ScopeDeferExpr *get_scope_defer_expr(Scope *scope) {
     while (scope) {
         if (scope->id == ScopeIdDeferExpr)
@@ -5665,6 +5677,15 @@ static IrInstruction *ir_gen_return_from_block(IrBuilder *irb, Scope *break_scop
     return ir_build_br(irb, break_scope, node, dest_block, is_comptime);
 }
 
+static IrInstruction *ir_gen_break_from_suspend(IrBuilder *irb, Scope *break_scope, AstNode *node, ScopeSuspend *suspend_scope) {
+    IrInstruction *is_comptime = ir_build_const_bool(irb, break_scope, node, false);
+
+    IrBasicBlock *dest_block = suspend_scope->resume_block;
+    ir_gen_defers_for_block(irb, break_scope, dest_block->scope, false);
+
+    return ir_build_br(irb, break_scope, node, dest_block, is_comptime);
+}
+
 static IrInstruction *ir_gen_break(IrBuilder *irb, Scope *break_scope, AstNode *node) {
     assert(node->type == NodeTypeBreak);
 
@@ -5704,6 +5725,13 @@ static IrInstruction *ir_gen_break(IrBuilder *irb, Scope *break_scope, AstNode *
                 assert(this_block_scope->end_block != nullptr);
                 return ir_gen_return_from_block(irb, break_scope, node, this_block_scope);
             }
+        } else if (search_scope->id == ScopeIdSuspend) {
+            ScopeSuspend *this_suspend_scope = (ScopeSuspend *)search_scope;
+            if (node->data.break_expr.name != nullptr &&
+                (this_suspend_scope->name != nullptr && buf_eql_buf(node->data.break_expr.name, this_suspend_scope->name)))
+            {
+                return ir_gen_break_from_suspend(irb, break_scope, node, this_suspend_scope);
+            }
         }
         search_scope = search_scope->parent;
     }
@@ -6290,14 +6318,26 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod
     ScopeDeferExpr *scope_defer_expr = get_scope_defer_expr(parent_scope);
     if (scope_defer_expr) {
         if (!scope_defer_expr->reported_err) {
-            add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression"));
+            ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression"));
+            add_error_note(irb->codegen, msg, scope_defer_expr->base.source_node, buf_sprintf("defer here"));
             scope_defer_expr->reported_err = true;
         }
         return irb->codegen->invalid_instruction;
     }
+    ScopeSuspend *existing_suspend_scope = get_scope_suspend(parent_scope);
+    if (existing_suspend_scope) {
+        if (!existing_suspend_scope->reported_err) {
+            ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside suspend block"));
+            add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("other suspend block here"));
+            existing_suspend_scope->reported_err = true;
+        }
+        return irb->codegen->invalid_instruction;
+    }
 
     Scope *outer_scope = irb->exec->begin_scope;
 
+    IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup");
+    IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume");
 
     IrInstruction *suspend_code;
     IrInstruction *const_bool_false = ir_build_const_bool(irb, parent_scope, node, false);
@@ -6316,28 +6356,28 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod
         } else {
             child_scope = parent_scope;
         }
+        ScopeSuspend *suspend_scope = create_suspend_scope(node, child_scope);
+        suspend_scope->resume_block = resume_block;
+        child_scope = &suspend_scope->base;
         IrInstruction *save_token = ir_build_coro_save(irb, child_scope, node, irb->exec->coro_handle);
         ir_gen_node(irb, node->data.suspend.block, child_scope);
-        suspend_code = ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false);
+        suspend_code = ir_mark_gen(ir_build_coro_suspend(irb, parent_scope, node, save_token, const_bool_false));
     }
 
-    IrBasicBlock *cleanup_block = ir_create_basic_block(irb, parent_scope, "SuspendCleanup");
-    IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume");
-
     IrInstructionSwitchBrCase *cases = allocate<IrInstructionSwitchBrCase>(2);
-    cases[0].value = ir_build_const_u8(irb, parent_scope, node, 0);
+    cases[0].value = ir_mark_gen(ir_build_const_u8(irb, parent_scope, node, 0));
     cases[0].block = resume_block;
-    cases[1].value = ir_build_const_u8(irb, parent_scope, node, 1);
+    cases[1].value = ir_mark_gen(ir_build_const_u8(irb, parent_scope, node, 1));
     cases[1].block = cleanup_block;
-    ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block,
-            2, cases, const_bool_false);
+    ir_mark_gen(ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block,
+            2, cases, const_bool_false));
 
     ir_set_cursor_at_end_and_append_block(irb, cleanup_block);
     ir_gen_defers_for_block(irb, parent_scope, outer_scope, true);
     ir_mark_gen(ir_build_br(irb, parent_scope, node, irb->exec->coro_final_cleanup_block, const_bool_false));
 
     ir_set_cursor_at_end_and_append_block(irb, resume_block);
-    return ir_build_const_void(irb, parent_scope, node);
+    return ir_mark_gen(ir_build_const_void(irb, parent_scope, node));
 }
 
 static IrInstruction *ir_gen_node_raw(IrBuilder *irb, AstNode *node, Scope *scope,
src/parser.cpp
@@ -648,12 +648,30 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, size_t *token_index, bool m
 }
 
 /*
-SuspendExpression(body) = "suspend" "|" Symbol "|" body
+SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
 */
 static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, bool mandatory) {
     size_t orig_token_index = *token_index;
 
-    Token *suspend_token = &pc->tokens->at(*token_index);
+    Token *name_token = nullptr;
+    Token *token = &pc->tokens->at(*token_index);
+    if (token->id == TokenIdSymbol) {
+        *token_index += 1;
+        Token *colon_token = &pc->tokens->at(*token_index);
+        if (colon_token->id == TokenIdColon) {
+            *token_index += 1;
+            name_token = token;
+            token = &pc->tokens->at(*token_index);
+        } else if (mandatory) {
+            ast_expect_token(pc, colon_token, TokenIdColon);
+            zig_unreachable();
+        } else {
+            *token_index = orig_token_index;
+            return nullptr;
+        }
+    }
+
+    Token *suspend_token = token;
     if (suspend_token->id == TokenIdKeywordSuspend) {
         *token_index += 1;
     } else if (mandatory) {
@@ -675,6 +693,9 @@ static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, b
     }
 
     AstNode *node = ast_create_node(pc, NodeTypeSuspend, suspend_token);
+    if (name_token != nullptr) {
+        node->data.suspend.name = token_buf(name_token);
+    }
     node->data.suspend.promise_symbol = ast_parse_symbol(pc, token_index);
     ast_eat_token(pc, token_index, TokenIdBinOr);
     node->data.suspend.block = ast_parse_block(pc, token_index, true);
test/cases/coroutines.zig
@@ -224,3 +224,21 @@ async fn printTrace(p: promise->error!void) void {
         }
     };
 }
+
+test "break from suspend" {
+    var buf: [500]u8 = undefined;
+    var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator;
+    var my_result: i32 = 1;
+    const p = try async<a> testBreakFromSuspend(&my_result);
+    cancel p;
+    std.debug.assert(my_result == 2);
+}
+
+async fn testBreakFromSuspend(my_result: &i32) void {
+    s: suspend |p| {
+        break :s;
+    }
+    *my_result += 1;
+    suspend;
+    *my_result += 1;
+}
test/compile_errors.zig
@@ -1,6 +1,26 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: &tests.CompileErrorContext) void {
+    cases.add("suspend inside suspend block",
+        \\const std = @import("std");
+        \\
+        \\export fn entry() void {
+        \\    var buf: [500]u8 = undefined;
+        \\    var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator;
+        \\    const p = (async<a> foo()) catch unreachable;
+        \\    cancel p;
+        \\}
+        \\
+        \\async fn foo() void {
+        \\    suspend |p| {
+        \\        suspend |p1| {
+        \\        }
+        \\    }
+        \\}
+    ,
+        ".tmp_source.zig:12:9: error: cannot suspend inside suspend block",
+        ".tmp_source.zig:11:5: note: other suspend block here");
+
     cases.add("assign inline fn to non-comptime var",
         \\export fn entry() void {
         \\    var a = b;