Commit ec33e5a638

Andrew Kelley <superjoe30@gmail.com>
2016-02-06 08:56:01
simple unconditional defer support
See #110
1 parent 6a2ede5
doc/langref.md
@@ -43,7 +43,7 @@ ParamDecl = option("noalias") option("Symbol" ":") TypeExpr | "..."
 
 Block = "{" list(option(Statement), ";") "}"
 
-Statement = Label | VariableDeclaration ";" | NonBlockExpression ";" | BlockExpression
+Statement = Label | VariableDeclaration ";" | Defer ";" | NonBlockExpression ";" | BlockExpression
 
 Label = "Symbol" ":"
 
@@ -51,7 +51,7 @@ Expression = BlockExpression | NonBlockExpression
 
 TypeExpr = PrefixOpExpression
 
-NonBlockExpression = ReturnExpression | AssignmentExpression | DeferExpression
+NonBlockExpression = ReturnExpression | AssignmentExpression
 
 AsmExpression = "asm" option("volatile") "(" "String" option(AsmOutput) ")"
 
@@ -91,7 +91,7 @@ BoolOrExpression = BoolAndExpression "||" BoolOrExpression | BoolAndExpression
 
 ReturnExpression = option("%" | "?") "return" option(Expression)
 
-DeferExpression = option("%" | "?") "defer" option(Expression)
+Defer = option("%" | "?") "defer" option(Expression)
 
 IfExpression = IfVarExpression | IfBoolExpression
 
example/cat/main.zig
@@ -4,7 +4,6 @@ import "std.zig";
 
 // Things to do to make this work:
 // * var args printing
-// * defer
 // * cast err type to string
 // * string equality
 
src/all_types.hpp
@@ -117,7 +117,7 @@ enum NodeType {
     NodeTypeBlock,
     NodeTypeDirective,
     NodeTypeReturnExpr,
-    NodeTypeDeferExpr,
+    NodeTypeDefer,
     NodeTypeVariableDeclaration,
     NodeTypeTypeDecl,
     NodeTypeErrorValueDecl,
@@ -216,7 +216,12 @@ struct AstNodeBlock {
     ZigList<AstNode *> statements;
 
     // populated by semantic analyzer
-    BlockContext *block_context;
+    // this one is the scope that the block itself introduces
+    BlockContext *child_block;
+    // this is the innermost scope created by defers and var decls.
+    // you can follow its parents up to child_block. it will equal
+    // child_block if there are no defers or var decls in the block.
+    BlockContext *nested_block;
     Expr resolved_expr;
 };
 
@@ -235,7 +240,7 @@ struct AstNodeReturnExpr {
     Expr resolved_expr;
 };
 
-struct AstNodeDeferExpr {
+struct AstNodeDefer {
     ReturnKind kind;
     AstNode *expr;
 
@@ -243,6 +248,7 @@ struct AstNodeDeferExpr {
     Expr resolved_expr;
     int index_in_block;
     LLVMBasicBlockRef basic_block;
+    BlockContext *child_block;
 };
 
 struct AstNodeVariableDeclaration {
@@ -739,7 +745,7 @@ struct AstNode {
         AstNodeParamDecl param_decl;
         AstNodeBlock block;
         AstNodeReturnExpr return_expr;
-        AstNodeDeferExpr defer_expr;
+        AstNodeDefer defer;
         AstNodeVariableDeclaration variable_declaration;
         AstNodeTypeDecl type_decl;
         AstNodeErrorValueDecl error_value_decl;
@@ -1157,10 +1163,12 @@ enum BlockExitPath {
     BlockExitPathFallthrough,
     BlockExitPathReturn,
     BlockExitPathGoto,
+
+    BlockExitPathCount,
 };
 
 struct BlockContext {
-    // One of: NodeTypeFnDef, NodeTypeBlock, NodeTypeRoot, NodeTypeDeferExpr, NodeTypeVariableDeclaration
+    // One of: NodeTypeFnDef, NodeTypeBlock, NodeTypeRoot, NodeTypeDefer, NodeTypeVariableDeclaration
     AstNode *node;
 
     // any variables that are introduced by this scope
@@ -1178,7 +1186,7 @@ struct BlockContext {
 
     LLVMZigDIScope *di_scope;
     Buf *c_import_buf;
-    bool block_exit_paths[3]; // one for each BlockExitPath
+    bool block_exit_paths[BlockExitPathCount];
 };
 
 enum CIntType {
src/analyze.cpp
@@ -57,7 +57,7 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeBlock:
         case NodeTypeDirective:
         case NodeTypeReturnExpr:
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
         case NodeTypeVariableDeclaration:
         case NodeTypeTypeDecl:
         case NodeTypeErrorValueDecl:
@@ -1456,7 +1456,7 @@ static void resolve_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeParamDecl:
         case NodeTypeFnDecl:
         case NodeTypeReturnExpr:
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
         case NodeTypeRoot:
         case NodeTypeBlock:
         case NodeTypeBinOpExpr:
@@ -4590,56 +4590,54 @@ static void validate_voided_expr(CodeGen *g, AstNode *source_node, TypeTableEntr
     }
 }
 
-static TypeTableEntry *analyze_defer_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+static TypeTableEntry *analyze_defer(CodeGen *g, ImportTableEntry *import, BlockContext *parent_context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    if (!context->fn_entry) {
+    if (!parent_context->fn_entry) {
         add_node_error(g, node, buf_sprintf("defer expression outside function definition"));
         return g->builtin_types.entry_invalid;
     }
 
-    if (!node->data.defer_expr.expr) {
+    if (!node->data.defer.expr) {
         add_node_error(g, node, buf_sprintf("defer expects an expression"));
         return g->builtin_types.entry_void;
     }
 
+    node->data.defer.child_block = new_block_context(node, parent_context);
 
-    switch (node->data.defer_expr.kind) {
+    switch (node->data.defer.kind) {
         case ReturnKindUnconditional:
             {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, context, nullptr,
-                        node->data.defer_expr.expr);
-                validate_voided_expr(g, node->data.defer_expr.expr, resolved_type);
-                zig_panic("TODO");
+                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
+                        node->data.defer.expr);
+                validate_voided_expr(g, node->data.defer.expr, resolved_type);
 
-                //node->data.defer_expr.index_in_block = context->defer_list.length;
-                //context->defer_list.append(node);
                 return g->builtin_types.entry_void;
             }
         case ReturnKindError:
             {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, context, nullptr,
-                        node->data.defer_expr.expr);
+                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
+                        node->data.defer.expr);
                 if (resolved_type->id == TypeTableEntryIdInvalid) {
                     // OK
                 } else if (resolved_type->id == TypeTableEntryIdErrorUnion) {
                     // OK
                 } else {
-                    add_node_error(g, node->data.defer_expr.expr,
+                    add_node_error(g, node->data.defer.expr,
                             buf_sprintf("expected error type, got '%s'", buf_ptr(&resolved_type->name)));
                 }
                 return g->builtin_types.entry_void;
             }
         case ReturnKindMaybe:
             {
-                TypeTableEntry *resolved_type = analyze_expression(g, import, context, nullptr,
-                        node->data.defer_expr.expr);
+                TypeTableEntry *resolved_type = analyze_expression(g, import, parent_context, nullptr,
+                        node->data.defer.expr);
                 if (resolved_type->id == TypeTableEntryIdInvalid) {
                     // OK
                 } else if (resolved_type->id == TypeTableEntryIdMaybe) {
                     // OK
                 } else {
-                    add_node_error(g, node->data.defer_expr.expr,
+                    add_node_error(g, node->data.defer.expr,
                             buf_sprintf("expected maybe type, got '%s'", buf_ptr(&resolved_type->name)));
                 }
                 return g->builtin_types.entry_void;
@@ -4657,11 +4655,11 @@ static TypeTableEntry *analyze_string_literal_expr(CodeGen *g, ImportTableEntry
     }
 }
 
-static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import, BlockContext *parent_context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    BlockContext *child_context = new_block_context(node, context);
-    node->data.block.block_context = child_context;
+    BlockContext *child_context = new_block_context(node, parent_context);
+    node->data.block.child_block = child_context;
     TypeTableEntry *return_type = g->builtin_types.entry_void;
 
     for (int i = 0; i < node->data.block.statements.length; i += 1) {
@@ -4676,7 +4674,7 @@ static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import,
             if (is_node_void_expr(child)) {
                 // {unreachable;void;void} is allowed.
                 // ignore void statements once we enter unreachable land.
-                analyze_expression(g, import, context, g->builtin_types.entry_void, child);
+                analyze_expression(g, import, child_context, g->builtin_types.entry_void, child);
                 continue;
             }
             add_node_error(g, first_executing_node(child), buf_sprintf("unreachable code"));
@@ -4685,10 +4683,16 @@ static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import,
         bool is_last = (i == node->data.block.statements.length - 1);
         TypeTableEntry *passed_expected_type = is_last ? expected_type : nullptr;
         return_type = analyze_expression(g, import, child_context, passed_expected_type, child);
+        if (child->type == NodeTypeDefer && return_type->id != TypeTableEntryIdInvalid) {
+            // defer starts a new block context
+            child_context = child->data.defer.child_block;
+            assert(child_context);
+        }
         if (!is_last) {
             validate_voided_expr(g, child, return_type);
         }
     }
+    node->data.block.nested_block = child_context;
     return return_type;
 }
 
@@ -4750,8 +4754,8 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeReturnExpr:
             return_type = analyze_return_expr(g, import, context, expected_type, node);
             break;
-        case NodeTypeDeferExpr:
-            return_type = analyze_defer_expr(g, import, context, expected_type, node);
+        case NodeTypeDefer:
+            return_type = analyze_defer(g, import, context, expected_type, node);
             break;
         case NodeTypeVariableDeclaration:
             analyze_variable_declaration(g, import, context, expected_type, node);
@@ -4956,7 +4960,7 @@ static void analyze_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeParamDecl:
         case NodeTypeFnDecl:
         case NodeTypeReturnExpr:
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
         case NodeTypeRoot:
         case NodeTypeBlock:
         case NodeTypeBinOpExpr:
@@ -5039,8 +5043,8 @@ static void collect_expr_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeReturnExpr:
             collect_expr_decl_deps(g, import, node->data.return_expr.expr, decl_node);
             break;
-        case NodeTypeDeferExpr:
-            collect_expr_decl_deps(g, import, node->data.defer_expr.expr, decl_node);
+        case NodeTypeDefer:
+            collect_expr_decl_deps(g, import, node->data.defer.expr, decl_node);
             break;
         case NodeTypePrefixOpExpr:
             collect_expr_decl_deps(g, import, node->data.prefix_op_expr.primary_expr, decl_node);
@@ -5361,7 +5365,7 @@ static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, Ast
         case NodeTypeParamDecl:
         case NodeTypeFnDecl:
         case NodeTypeReturnExpr:
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
         case NodeTypeBlock:
         case NodeTypeBinOpExpr:
         case NodeTypeUnwrapErrorExpr:
@@ -5552,8 +5556,8 @@ Expr *get_resolved_expr(AstNode *node) {
     switch (node->type) {
         case NodeTypeReturnExpr:
             return &node->data.return_expr.resolved_expr;
-        case NodeTypeDeferExpr:
-            return &node->data.defer_expr.resolved_expr;
+        case NodeTypeDefer:
+            return &node->data.defer.resolved_expr;
         case NodeTypeBinOpExpr:
             return &node->data.bin_op_expr.resolved_expr;
         case NodeTypeUnwrapErrorExpr:
@@ -5652,7 +5656,7 @@ TopLevelDecl *get_resolved_top_level_decl(AstNode *node) {
             return &node->data.type_decl.top_level_decl;
         case NodeTypeNumberLiteral:
         case NodeTypeReturnExpr:
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
         case NodeTypeBinOpExpr:
         case NodeTypeUnwrapErrorExpr:
         case NodeTypePrefixOpExpr:
src/ast_render.cpp
@@ -122,8 +122,8 @@ static const char *node_type_str(NodeType node_type) {
             return "Directive";
         case NodeTypeReturnExpr:
             return "ReturnExpr";
-        case NodeTypeDeferExpr:
-            return "DeferExpr";
+        case NodeTypeDefer:
+            return "Defer";
         case NodeTypeVariableDeclaration:
             return "VariableDeclaration";
         case NodeTypeTypeDecl:
@@ -261,12 +261,12 @@ void ast_print(FILE *f, AstNode *node, int indent) {
                     ast_print(f, node->data.return_expr.expr, indent + 2);
                 break;
             }
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
             {
-                const char *prefix_str = return_prefix_str(node->data.defer_expr.kind);
+                const char *prefix_str = return_prefix_str(node->data.defer.kind);
                 fprintf(f, "%s%s\n", prefix_str, node_type_str(node->type));
-                if (node->data.defer_expr.expr)
-                    ast_print(f, node->data.defer_expr.expr, indent + 2);
+                if (node->data.defer.expr)
+                    ast_print(f, node->data.defer.expr, indent + 2);
                 break;
             }
         case NodeTypeVariableDeclaration:
@@ -630,7 +630,7 @@ static void render_node(AstRender *ar, AstNode *node) {
             break;
         case NodeTypeReturnExpr:
             zig_panic("TODO");
-        case NodeTypeDeferExpr:
+        case NodeTypeDefer:
             zig_panic("TODO");
         case NodeTypeVariableDeclaration:
             {
src/codegen.cpp
@@ -1669,11 +1669,9 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
     }
 }
 
-static LLVMValueRef gen_defer_expr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeDeferExpr);
+static LLVMValueRef gen_defer(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeDefer);
 
-    zig_panic("TODO");
-    //node->block_context->cur_defer_index = node->data.defer_expr.index_in_block;
 
     return nullptr;
 }
@@ -1800,31 +1798,37 @@ static LLVMValueRef gen_if_var_expr(CodeGen *g, AstNode *node) {
     return return_value;
 }
 
+//static int block_exit_path_count(BlockContext *block_context) {
+//    int sum = 0;
+//    for (int i = 0; i < BlockExitPathCount; i += 1) {
+//        sum += block_context->block_exit_paths[i] ? 1 : 0;
+//    }
+//    return sum;
+//}
+
 static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) {
     assert(block_node->type == NodeTypeBlock);
 
-    /* TODO
-    BlockContext *block_context = block_node->data.block.block_context;
-    if (block_context->defer_list.length > 0) {
-        LLVMBasicBlockRef exit_scope_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "DeferExitScope");
-
-        for (int i = 0; i < block_context->defer_list.length; i += 1) {
-            AstNode *defer_node = block_context->defer_list.at(i);
-            defer_node->data.defer_expr.basic_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "DeferExpr");
-            LLVMPositionBuilderAtEnd(g->builder, body_block);
-        }
-
-        LLVMPositionBuilderAtEnd(g->builder, ?);
-    }
-    */
-
     LLVMValueRef return_value;
     for (int i = 0; i < block_node->data.block.statements.length; i += 1) {
         AstNode *statement_node = block_node->data.block.statements.at(i);
         return_value = gen_expr(g, statement_node);
     }
 
-    if (implicit_return_type && implicit_return_type->id != TypeTableEntryIdUnreachable) {
+    bool end_unreachable = implicit_return_type && implicit_return_type->id == TypeTableEntryIdUnreachable;
+    if (end_unreachable) {
+        return nullptr;
+    }
+
+    BlockContext *block_context = block_node->data.block.nested_block;
+    while (block_context != block_node->data.block.child_block) {
+        if (block_context->node->type == NodeTypeDefer) {
+            gen_expr(g, block_context->node->data.defer.expr);
+        }
+        block_context = block_context->parent;
+    }
+
+    if (implicit_return_type) {
         return gen_return(g, block_node, return_value);
     } else {
         return return_value;
@@ -2475,8 +2479,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
             return gen_unwrap_err_expr(g, node);
         case NodeTypeReturnExpr:
             return gen_return_expr(g, node);
-        case NodeTypeDeferExpr:
-            return gen_defer_expr(g, node);
+        case NodeTypeDefer:
+            return gen_defer(g, node);
         case NodeTypeVariableDeclaration:
             return gen_var_decl_expr(g, node);
         case NodeTypePrefixOpExpr:
src/parser.cpp
@@ -1651,7 +1651,7 @@ static AstNode *ast_parse_return_or_defer_expr(ParseContext *pc, int *token_inde
             *token_index += 2;
         } else if (next_token->id == TokenIdKeywordDefer) {
             kind = ReturnKindError;
-            node_type = NodeTypeDeferExpr;
+            node_type = NodeTypeDefer;
             *token_index += 2;
         } else {
             return nullptr;
@@ -1664,7 +1664,7 @@ static AstNode *ast_parse_return_or_defer_expr(ParseContext *pc, int *token_inde
             *token_index += 2;
         } else if (next_token->id == TokenIdKeywordDefer) {
             kind = ReturnKindMaybe;
-            node_type = NodeTypeDeferExpr;
+            node_type = NodeTypeDefer;
             *token_index += 2;
         } else {
             return nullptr;
@@ -1675,7 +1675,7 @@ static AstNode *ast_parse_return_or_defer_expr(ParseContext *pc, int *token_inde
         *token_index += 1;
     } else if (token->id == TokenIdKeywordDefer) {
         kind = ReturnKindUnconditional;
-        node_type = NodeTypeDeferExpr;
+        node_type = NodeTypeDefer;
         *token_index += 1;
     } else {
         return nullptr;
@@ -2703,8 +2703,8 @@ void normalize_parent_ptrs(AstNode *node) {
         case NodeTypeReturnExpr:
             set_field(&node->data.return_expr.expr);
             break;
-        case NodeTypeDeferExpr:
-            set_field(&node->data.defer_expr.expr);
+        case NodeTypeDefer:
+            set_field(&node->data.defer.expr);
             break;
         case NodeTypeVariableDeclaration:
             set_list_fields(node->data.variable_declaration.directives);
test/run_tests.cpp
@@ -1519,6 +1519,19 @@ pub fn main(args: [][]u8) -> %void {
     %%stdout.printf("OK\n");
 }
     )SOURCE", "OK\n");
+
+
+    add_simple_case("defer with only fallthrough", R"SOURCE(
+import "std.zig";
+pub fn main(args: [][]u8) -> %void {
+    %%stdout.printf("before\n");
+    defer %%stdout.printf("defer1\n");
+    defer %%stdout.printf("defer2\n");
+    defer %%stdout.printf("defer3\n");
+    %%stdout.printf("after\n");
+}
+    )SOURCE", "before\nafter\ndefer3\ndefer2\ndefer1\n");
+
 }