Commit 128e70ff3a

LemonBoy <thatlemon@gmail.com>
2020-03-19 21:03:38
ir: Allow errdefer with payload
Closes #1265
1 parent 153c6cf
Changed files (8)
lib/std/zig/ast.zig
@@ -1032,6 +1032,7 @@ pub const Node = struct {
     pub const Defer = struct {
         base: Node = Node{ .id = .Defer },
         defer_token: TokenIndex,
+        payload: ?*Node,
         expr: *Node,
 
         pub fn iterate(self: *Defer, index: usize) ?*Node {
@@ -1833,8 +1834,7 @@ pub const Node = struct {
             var i = index;
 
             switch (self.kind) {
-                .Break,
-                .Continue => |maybe_label| {
+                .Break, .Continue => |maybe_label| {
                     if (maybe_label) |label| {
                         if (i < 1) return label;
                         i -= 1;
@@ -1861,8 +1861,7 @@ pub const Node = struct {
             }
 
             switch (self.kind) {
-                .Break,
-                .Continue => |maybe_label| {
+                .Break, .Continue => |maybe_label| {
                     if (maybe_label) |label| {
                         return label.lastToken();
                     }
lib/std/zig/parse.zig
@@ -465,7 +465,7 @@ fn parseContainerField(arena: *Allocator, it: *TokenIterator, tree: *Tree) !?*No
 ///      / KEYWORD_noasync BlockExprStatement
 ///      / KEYWORD_suspend (SEMICOLON / BlockExprStatement)
 ///      / KEYWORD_defer BlockExprStatement
-///      / KEYWORD_errdefer BlockExprStatement
+///      / KEYWORD_errdefer Payload? BlockExprStatement
 ///      / IfStatement
 ///      / LabeledStatement
 ///      / SwitchExpr
@@ -526,6 +526,10 @@ fn parseStatement(arena: *Allocator, it: *TokenIterator, tree: *Tree) Error!?*No
 
     const defer_token = eatToken(it, .Keyword_defer) orelse eatToken(it, .Keyword_errdefer);
     if (defer_token) |token| {
+        const payload = if (tree.tokens.at(token).id == .Keyword_errdefer)
+            try parsePayload(arena, it, tree)
+        else
+            null;
         const expr_node = try expectNode(arena, it, tree, parseBlockExprStatement, .{
             .ExpectedBlockOrExpression = .{ .token = it.index },
         });
@@ -533,6 +537,7 @@ fn parseStatement(arena: *Allocator, it: *TokenIterator, tree: *Tree) Error!?*No
         node.* = .{
             .defer_token = token,
             .expr = expr_node,
+            .payload = payload,
         };
         return &node.base;
     }
lib/std/zig/parser_test.zig
@@ -1,3 +1,16 @@
+test "zig fmt: noasync block" {
+    try testCanonical(
+        \\pub fn main() anyerror!void {
+        \\    errdefer |a| x += 1;
+        \\    errdefer |a| {}
+        \\    errdefer |a| {
+        \\        x += 1;
+        \\    }
+        \\}
+        \\
+    );
+}
+
 test "zig fmt: noasync block" {
     try testCanonical(
         \\pub fn main() anyerror!void {
lib/std/zig/render.zig
@@ -376,6 +376,9 @@ fn renderExpression(
             const defer_node = @fieldParentPtr(ast.Node.Defer, "base", base);
 
             try renderToken(tree, stream, defer_node.defer_token, indent, start_col, Space.Space);
+            if (defer_node.payload) |payload| {
+                try renderExpression(allocator, stream, tree, indent, start_col, payload, Space.Space);
+            }
             return renderExpression(allocator, stream, tree, indent, start_col, defer_node.expr, space);
         },
         .Comptime => {
src/all_types.hpp
@@ -744,6 +744,7 @@ struct AstNodeReturnExpr {
 
 struct AstNodeDefer {
     ReturnKind kind;
+    AstNode *err_payload;
     AstNode *expr;
 
     // temporary data used in IR generation
src/ir.cpp
@@ -272,6 +272,10 @@ static ResultLoc *no_result_loc(void);
 static IrInstGen *ir_analyze_test_non_null(IrAnalyze *ira, IrInst *source_inst, IrInstGen *value);
 static IrInstGen *ir_error_dependency_loop(IrAnalyze *ira, IrInst *source_instr);
 static IrInstGen *ir_const_undef(IrAnalyze *ira, IrInst *source_instruction, ZigType *ty);
+static ZigVar *ir_create_var(IrBuilderSrc *irb, AstNode *node, Scope *scope, Buf *name,
+        bool src_is_const, bool gen_is_const, bool is_shadowable, IrInstSrc *is_comptime);
+static void build_decl_var_and_init(IrBuilderSrc *irb, Scope *scope, AstNode *source_node, ZigVar *var,
+        IrInstSrc *init, const char *name_hint, IrInstSrc *is_comptime);
 
 static void destroy_instruction_src(IrInstSrc *inst) {
     switch (inst->id) {
@@ -5011,39 +5015,73 @@ static IrInstSrc *ir_mark_gen(IrInstSrc *instruction) {
     return instruction;
 }
 
-static bool ir_gen_defers_for_block(IrBuilderSrc *irb, Scope *inner_scope, Scope *outer_scope, bool gen_error_defers) {
+static bool ir_gen_defers_for_block(IrBuilderSrc *irb, Scope *inner_scope, Scope *outer_scope, bool *is_noreturn, IrInstSrc *err_value) {
     Scope *scope = inner_scope;
-    bool is_noreturn = false;
+    if (is_noreturn != nullptr) *is_noreturn = false;
     while (scope != outer_scope) {
         if (!scope)
-            return is_noreturn;
+            return true;
 
         switch (scope->id) {
             case ScopeIdDefer: {
                 AstNode *defer_node = scope->source_node;
                 assert(defer_node->type == NodeTypeDefer);
                 ReturnKind defer_kind = defer_node->data.defer.kind;
-                if (defer_kind == ReturnKindUnconditional ||
-                    (gen_error_defers && defer_kind == ReturnKindError))
-                {
-                    AstNode *defer_expr_node = defer_node->data.defer.expr;
-                    Scope *defer_expr_scope = defer_node->data.defer.expr_scope;
-                    IrInstSrc *defer_expr_value = ir_gen_node(irb, defer_expr_node, defer_expr_scope);
-                    if (defer_expr_value != irb->codegen->invalid_inst_src) {
-                        if (defer_expr_value->is_noreturn) {
-                            is_noreturn = true;
-                        } else {
-                            ir_mark_gen(ir_build_check_statement_is_void(irb, defer_expr_scope, defer_expr_node,
-                                        defer_expr_value));
-                        }
+                AstNode *defer_expr_node = defer_node->data.defer.expr;
+                AstNode *defer_var_node = defer_node->data.defer.err_payload;
+
+                if (defer_kind == ReturnKindError && err_value == nullptr) {
+                    // This is an `errdefer` but we're generating code for a
+                    // `return` that doesn't return an error, skip it
+                    scope = scope->parent;
+                    continue;
+                }
+
+                Scope *defer_expr_scope = defer_node->data.defer.expr_scope;
+                if (defer_var_node != nullptr) {
+                    assert(defer_kind == ReturnKindError);
+                    assert(defer_var_node->type == NodeTypeSymbol);
+                    Buf *var_name = defer_var_node->data.symbol_expr.symbol;
+
+                    if (defer_expr_node->type == NodeTypeUnreachable) {
+                        add_node_error(irb->codegen, defer_var_node,
+                            buf_sprintf("unused variable: '%s'", buf_ptr(var_name)));
+                        return false;
+                    }
+
+                    IrInstSrc *is_comptime;
+                    if (ir_should_inline(irb->exec, defer_expr_scope)) {
+                        is_comptime = ir_build_const_bool(irb, defer_expr_scope,
+                            defer_expr_node, true);
+                    } else {
+                        is_comptime = ir_build_test_comptime(irb, defer_expr_scope,
+                            defer_expr_node, err_value);
                     }
+
+                    ZigVar *err_var = ir_create_var(irb, defer_var_node, defer_expr_scope,
+                        var_name, true, true, false, is_comptime);
+                    build_decl_var_and_init(irb, defer_expr_scope, defer_var_node, err_var, err_value,
+                        buf_ptr(var_name), is_comptime);
+
+                    defer_expr_scope = err_var->child_scope;
+                }
+
+                IrInstSrc *defer_expr_value = ir_gen_node(irb, defer_expr_node, defer_expr_scope);
+                if (defer_expr_value == irb->codegen->invalid_inst_src)
+                    return irb->codegen->invalid_inst_src;
+
+                if (defer_expr_value->is_noreturn) {
+                    if (is_noreturn != nullptr) *is_noreturn = true;
+                } else {
+                    ir_mark_gen(ir_build_check_statement_is_void(irb, defer_expr_scope, defer_expr_node,
+                                defer_expr_value));
                 }
                 scope = scope->parent;
                 continue;
             }
             case ScopeIdDecls:
             case ScopeIdFnDef:
-                return is_noreturn;
+                return true;
             case ScopeIdBlock:
             case ScopeIdVarDecl:
             case ScopeIdLoop:
@@ -5060,7 +5098,7 @@ static bool ir_gen_defers_for_block(IrBuilderSrc *irb, Scope *inner_scope, Scope
                 zig_unreachable();
         }
     }
-    return is_noreturn;
+    return true;
 }
 
 static void ir_set_cursor_at_end_gen(IrBuilderGen *irb, IrBasicBlockGen *basic_block) {
@@ -5146,7 +5184,8 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 bool have_err_defers = defer_counts[ReturnKindError] > 0;
                 if (!have_err_defers && !irb->codegen->have_err_ret_tracing) {
                     // only generate unconditional defers
-                    ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                    if (!ir_gen_defers_for_block(irb, scope, outer_scope, nullptr, nullptr))
+                        return irb->codegen->invalid_inst_src;
                     IrInstSrc *result = ir_build_return_src(irb, scope, node, nullptr);
                     result_loc_ret->base.source_instruction = result;
                     return result;
@@ -5169,14 +5208,16 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 IrBasicBlockSrc *ret_stmt_block = ir_create_basic_block(irb, scope, "RetStmt");
 
                 ir_set_cursor_at_end_and_append_block(irb, err_block);
-                ir_gen_defers_for_block(irb, scope, outer_scope, true);
+                if (!ir_gen_defers_for_block(irb, scope, outer_scope, nullptr, return_value))
+                    return irb->codegen->invalid_inst_src;
                 if (irb->codegen->have_err_ret_tracing && !should_inline) {
                     ir_build_save_err_ret_addr_src(irb, scope, node);
                 }
                 ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                 ir_set_cursor_at_end_and_append_block(irb, ok_block);
-                ir_gen_defers_for_block(irb, scope, outer_scope, false);
+                if (!ir_gen_defers_for_block(irb, scope, outer_scope, nullptr, nullptr))
+                    return irb->codegen->invalid_inst_src;
                 ir_build_br(irb, scope, node, ret_stmt_block, is_comptime);
 
                 ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block);
@@ -5213,7 +5254,12 @@ static IrInstSrc *ir_gen_return(IrBuilderSrc *irb, Scope *scope, AstNode *node,
                 result_loc_ret->base.id = ResultLocIdReturn;
                 ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
                 ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
-                if (!ir_gen_defers_for_block(irb, scope, outer_scope, true)) {
+
+                bool is_noreturn = false;
+                if (!ir_gen_defers_for_block(irb, scope, outer_scope, &is_noreturn, err_val)) {
+                    return irb->codegen->invalid_inst_src;
+                }
+                if (!is_noreturn) {
                     if (irb->codegen->have_err_ret_tracing && !should_inline) {
                         ir_build_save_err_ret_addr_src(irb, scope, node);
                     }
@@ -5415,7 +5461,8 @@ static IrInstSrc *ir_gen_block(IrBuilderSrc *irb, Scope *parent_scope, AstNode *
 
     bool is_return_from_fn = block_node == irb->main_block_node;
     if (!is_return_from_fn) {
-        ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
+        if (!ir_gen_defers_for_block(irb, child_scope, outer_block_scope, nullptr, nullptr))
+            return irb->codegen->invalid_inst_src;
     }
 
     IrInstSrc *result;
@@ -5440,7 +5487,8 @@ static IrInstSrc *ir_gen_block(IrBuilderSrc *irb, Scope *parent_scope, AstNode *
     result_loc_ret->base.id = ResultLocIdReturn;
     ir_build_reset_result(irb, parent_scope, block_node, &result_loc_ret->base);
     ir_mark_gen(ir_build_end_expr(irb, parent_scope, block_node, result, &result_loc_ret->base));
-    ir_gen_defers_for_block(irb, child_scope, outer_block_scope, false);
+    if (!ir_gen_defers_for_block(irb, child_scope, outer_block_scope, nullptr, nullptr))
+        return irb->codegen->invalid_inst_src;
     return ir_mark_gen(ir_build_return_src(irb, child_scope, result->base.source_node, result));
 }
 
@@ -9240,7 +9288,8 @@ static IrInstSrc *ir_gen_return_from_block(IrBuilderSrc *irb, Scope *break_scope
     }
 
     IrBasicBlockSrc *dest_block = block_scope->end_block;
-    ir_gen_defers_for_block(irb, break_scope, dest_block->scope, false);
+    if (!ir_gen_defers_for_block(irb, break_scope, dest_block->scope, nullptr, nullptr))
+        return irb->codegen->invalid_inst_src;
 
     block_scope->incoming_blocks->append(irb->current_basic_block);
     block_scope->incoming_values->append(result_value);
@@ -9314,7 +9363,8 @@ static IrInstSrc *ir_gen_break(IrBuilderSrc *irb, Scope *break_scope, AstNode *n
     }
 
     IrBasicBlockSrc *dest_block = loop_scope->break_block;
-    ir_gen_defers_for_block(irb, break_scope, dest_block->scope, false);
+    if (!ir_gen_defers_for_block(irb, break_scope, dest_block->scope, nullptr, nullptr))
+        return irb->codegen->invalid_inst_src;
 
     loop_scope->incoming_blocks->append(irb->current_basic_block);
     loop_scope->incoming_values->append(result_value);
@@ -9373,7 +9423,8 @@ static IrInstSrc *ir_gen_continue(IrBuilderSrc *irb, Scope *continue_scope, AstN
     }
 
     IrBasicBlockSrc *dest_block = loop_scope->continue_block;
-    ir_gen_defers_for_block(irb, continue_scope, dest_block->scope, false);
+    if (!ir_gen_defers_for_block(irb, continue_scope, dest_block->scope, nullptr, nullptr))
+        return irb->codegen->invalid_inst_src;
     return ir_mark_gen(ir_build_br(irb, continue_scope, node, dest_block, is_comptime));
 }
 
src/parser.cpp
@@ -879,7 +879,7 @@ static AstNode *ast_parse_container_field(ParseContext *pc) {
 //      / KEYWORD_noasync BlockExprStatement
 //      / KEYWORD_suspend (SEMICOLON / BlockExprStatement)
 //      / KEYWORD_defer BlockExprStatement
-//      / KEYWORD_errdefer BlockExprStatement
+//      / KEYWORD_errdefer Payload? BlockExprStatement
 //      / IfStatement
 //      / LabeledStatement
 //      / SwitchExpr
@@ -923,12 +923,18 @@ static AstNode *ast_parse_statement(ParseContext *pc) {
     if (defer == nullptr)
         defer = eat_token_if(pc, TokenIdKeywordErrdefer);
     if (defer != nullptr) {
+        Token *payload = (defer->id == TokenIdKeywordErrdefer) ?
+            ast_parse_payload(pc) : nullptr;
         AstNode *statement = ast_expect(pc, ast_parse_block_expr_statement);
         AstNode *res = ast_create_node(pc, NodeTypeDefer, defer);
+
         res->data.defer.kind = ReturnKindUnconditional;
         res->data.defer.expr = statement;
-        if (defer->id == TokenIdKeywordErrdefer)
+        if (defer->id == TokenIdKeywordErrdefer) {
             res->data.defer.kind = ReturnKindError;
+            if (payload != nullptr)
+                res->data.defer.err_payload = token_symbol(pc, payload);
+        }
         return res;
     }
 
@@ -3032,6 +3038,7 @@ void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *cont
             break;
         case NodeTypeDefer:
             visit_field(&node->data.defer.expr, visit, context);
+            visit_field(&node->data.defer.err_payload, visit, context);
             break;
         case NodeTypeVariableDeclaration:
             visit_field(&node->data.variable_declaration.type, visit, context);
test/stage1/behavior/defer.zig
@@ -1,4 +1,6 @@
-const expect = @import("std").testing.expect;
+const std = @import("std");
+const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
 
 var result: [3]u8 = undefined;
 var index: usize = undefined;
@@ -93,3 +95,22 @@ test "return variable while defer expression in scope to modify it" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "errdefer with payload" {
+    const S = struct {
+        fn foo() !i32 {
+            errdefer |a| {
+                expectEqual(error.One, a);
+            }
+            return error.One;
+        }
+        fn doTheTest() void {
+            _ = foo() catch |err| switch (err) {
+                error.One => {},
+                else => unreachable,
+            };
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}