Commit 1ed926c321

Josh Wolfe <thejoshwolfe@gmail.com>
2015-12-01 23:54:46
implicit void statements and all tests pass with type checking
1 parent c6a9ab1
src/analyze.cpp
@@ -270,6 +270,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeUnreachable:
+        case NodeTypeVoid:
         case NodeTypeSymbol:
         case NodeTypeCastExpr:
         case NodeTypePrefixOpExpr:
@@ -311,8 +312,12 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                 for (int i = 0; i < node->data.block.statements.length; i += 1) {
                     AstNode *child = node->data.block.statements.at(i);
                     if (return_type == g->builtin_types.entry_unreachable) {
-                        add_node_error(g, child,
-                                buf_sprintf("unreachable code"));
+                        if (child->type == NodeTypeVoid) {
+                            // {unreachable;void;void} is allowed: void statements that
+                            // follow an unreachable expression are skipped, not reported.
+                            continue;
+                        }
+                        add_node_error(g, child, buf_sprintf("unreachable code"));
                         break;
                     }
                     return_type = analyze_expression(g, import, context, nullptr, child);
@@ -415,6 +420,10 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
             return_type = g->builtin_types.entry_unreachable;
             break;
 
+        case NodeTypeVoid:
+            return_type = g->builtin_types.entry_void;
+            break;
+
         case NodeTypeSymbol:
             // look up symbol in symbol table
             zig_panic("TODO");
@@ -439,59 +448,6 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
     return return_type;
 }
 
-static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
-    // Follow the execution flow and make sure the code returns appropriately.
-    // * A `return` statement in an unreachable type function should be an error.
-    // * Control flow should not be able to reach the end of an unreachable type function.
-    // * Functions that have a type other than void should not return without a value.
-    // * void functions without explicit return statements at the end need the
-    //   add_implicit_return flag set on the codegen node.
-    assert(node->type == NodeTypeFnDef);
-    AstNode *proto_node = node->data.fn_def.fn_proto;
-    assert(proto_node->type == NodeTypeFnProto);
-    AstNode *return_type_node = proto_node->data.fn_proto.return_type;
-    assert(return_type_node->type == NodeTypeType);
-
-    node->codegen_node = allocate<CodeGenNode>(1);
-    FnDefNode *codegen_fn_def = &node->codegen_node->data.fn_def_node;
-
-    assert(return_type_node->codegen_node);
-    TypeTableEntry *type_entry = return_type_node->codegen_node->data.type_node.entry;
-    assert(type_entry);
-
-    AstNode *body_node = node->data.fn_def.body;
-    assert(body_node->type == NodeTypeBlock);
-
-    // TODO once we understand types, do this pass after type checking, and
-    // if an expression has an unreachable value then stop looking at statements after
-    // it. then we can remove the check to `unreachable` in the end of this function.
-    bool prev_statement_return = false;
-    for (int i = 0; i < body_node->data.block.statements.length; i += 1) {
-        AstNode *statement_node = body_node->data.block.statements.at(i);
-        if (statement_node->type == NodeTypeReturnExpr) {
-            if (type_entry == g->builtin_types.entry_unreachable) {
-                add_node_error(g, statement_node,
-                        buf_sprintf("return statement in function with unreachable return type"));
-                return;
-            } else {
-                prev_statement_return = true;
-            }
-        } else if (prev_statement_return) {
-            add_node_error(g, statement_node,
-                    buf_sprintf("unreachable code"));
-        }
-    }
-
-    if (!prev_statement_return) {
-        if (type_entry == g->builtin_types.entry_void) {
-            codegen_fn_def->add_implicit_return = true;
-        } else if (type_entry != g->builtin_types.entry_unreachable) {
-            add_node_error(g, node,
-                    buf_sprintf("control reaches end of non-void function"));
-        }
-    }
-}
-
 static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, AstNode *node) {
     switch (node->type) {
         case NodeTypeFnDef:
@@ -512,14 +468,15 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
                     // TODO: define local variables for parameters
                 }
 
-                check_fn_def_control_flow(g, node);
-
                 BlockContext context;
                 context.node = node;
                 context.root = &context;
                 context.parent = nullptr;
                 TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry;
-                analyze_expression(g, import, &context, expected_type, node->data.fn_def.body);
+                TypeTableEntry *block_return_type = analyze_expression(g, import, &context, expected_type, node->data.fn_def.body);
+
+                node->codegen_node = allocate<CodeGenNode>(1);
+                node->codegen_node->data.fn_def_node.implicit_return_type = block_return_type;
             }
             break;
 
@@ -548,6 +505,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeUnreachable:
+        case NodeTypeVoid:
         case NodeTypeSymbol:
         case NodeTypeCastExpr:
         case NodeTypePrefixOpExpr:
src/codegen.cpp
@@ -401,6 +401,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeUnreachable:
             add_debug_source_node(g, node);
             return LLVMBuildUnreachable(g->builder);
+        case NodeTypeVoid:
+            return nullptr;
         case NodeTypeNumberLiteral:
             {
                 Buf *number_str = &node->data.number;
@@ -441,7 +443,7 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
     zig_unreachable();
 }
 
-static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, bool add_implicit_return) {
+static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, TypeTableEntry *implicit_return_type) {
     assert(block_node->type == NodeTypeBlock);
 
     LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder, g->block_scopes.last(),
@@ -450,13 +452,16 @@ static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node,
 
     add_debug_source_node(g, block_node);
 
+    LLVMValueRef return_value;
     for (int i = 0; i < block_node->data.block.statements.length; i += 1) {
         AstNode *statement_node = block_node->data.block.statements.at(i);
-        gen_expr(g, statement_node);
+        return_value = gen_expr(g, statement_node);
     }
 
-    if (add_implicit_return) {
+    if (implicit_return_type == g->builtin_types.entry_void) {
         LLVMBuildRetVoid(g->builder);
+    } else if (implicit_return_type != g->builtin_types.entry_unreachable) {
+        LLVMBuildRet(g->builder, return_value);
     }
 
     g->block_scopes.pop();
@@ -552,8 +557,8 @@ static void do_code_gen(CodeGen *g) {
         codegen_fn_def->params = allocate<LLVMValueRef>(LLVMCountParams(fn));
         LLVMGetParams(fn, codegen_fn_def->params);
 
-        bool add_implicit_return = codegen_fn_def->add_implicit_return;
-        gen_block(g, import, fn_def_node->data.fn_def.body, add_implicit_return);
+        TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type;
+        gen_block(g, import, fn_def_node->data.fn_def.body, implicit_return_type);
 
         g->block_scopes.pop();
     }
src/parser.cpp
@@ -89,6 +89,8 @@ const char *node_type_str(NodeType node_type) {
             return "PrefixOpExpr";
         case NodeTypeUse:
             return "Use";
+        case NodeTypeVoid:
+            return "Void";
     }
     zig_unreachable();
 }
@@ -233,6 +235,9 @@ void ast_print(AstNode *node, int indent) {
         case NodeTypeUse:
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.use.path));
             break;
+        case NodeTypeVoid:
+            fprintf(stderr, "Void\n");
+            break;
     }
 }
 
@@ -416,6 +421,9 @@ static AstNode *ast_parse_type(ParseContext *pc, int token_index, int *new_token
     if (token->id == TokenIdKeywordUnreachable) {
         node->data.type.type = AstNodeTypeTypePrimitive;
         buf_init_from_str(&node->data.type.primitive_name, "unreachable");
+    } else if (token->id == TokenIdKeywordVoid) {
+        node->data.type.type = AstNodeTypeTypePrimitive;
+        buf_init_from_str(&node->data.type.primitive_name, "void");
     } else if (token->id == TokenIdSymbol) {
         node->data.type.type = AstNodeTypeTypePrimitive;
         ast_buf_from_token(pc, token, &node->data.type.primitive_name);
@@ -569,6 +577,10 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         AstNode *node = ast_create_node(pc, NodeTypeUnreachable, token);
         *token_index += 1;
         return node;
+    } else if (token->id == TokenIdKeywordVoid) {
+        AstNode *node = ast_create_node(pc, NodeTypeVoid, token);
+        *token_index += 1;
+        return node;
     } else if (token->id == TokenIdSymbol) {
         AstNode *node = ast_create_node(pc, NodeTypeSymbol, token);
         ast_buf_from_token(pc, token, &node->data.symbol);
@@ -1024,50 +1036,42 @@ static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool ma
 }
 
 /*
-ExpressionStatement : Expression token(Semicolon)
-*/
-static AstNode *ast_parse_expression_statement(ParseContext *pc, int *token_index) {
-    AstNode *expr_node = ast_parse_expression(pc, token_index, true);
-
-    Token *semicolon = &pc->tokens->at(*token_index);
-    *token_index += 1;
-    ast_expect_token(pc, semicolon, TokenIdSemicolon);
-
-    return expr_node;
-}
-
-/*
-Statement : ExpressionStatement
-*/
-static AstNode *ast_parse_statement(ParseContext *pc, int *token_index) {
-    return ast_parse_expression_statement(pc, token_index);
-}
-
-/*
-Block : token(LBrace) many(Statement) token(RBrace);
+Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace)
 */
 static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandatory) {
-    Token *l_brace = &pc->tokens->at(*token_index);
+    Token *last_token = &pc->tokens->at(*token_index);
 
-    if (l_brace->id != TokenIdLBrace) {
+    if (last_token->id != TokenIdLBrace) {
         if (mandatory) {
-            ast_invalid_token_error(pc, l_brace);
+            ast_invalid_token_error(pc, last_token);
         } else {
             return nullptr;
         }
     }
     *token_index += 1;
 
-    AstNode *node = ast_create_node(pc, NodeTypeBlock, l_brace);
+    AstNode *node = ast_create_node(pc, NodeTypeBlock, last_token);
 
+    // {}   -> {void}
+    // {;}  -> {void;void}
+    // {2}  -> {2}
+    // {2;} -> {2;void}
+    // {;2} -> {void;2}
     for (;;) {
-        Token *token = &pc->tokens->at(*token_index);
-        if (token->id == TokenIdRBrace) {
+        AstNode *expression_node = ast_parse_expression(pc, token_index, false);
+        if (!expression_node) {
+            expression_node = ast_create_node(pc, NodeTypeVoid, last_token);
+        }
+        node->data.block.statements.append(expression_node);
+
+        last_token = &pc->tokens->at(*token_index);
+        if (last_token->id == TokenIdRBrace) {
             *token_index += 1;
             return node;
+        } else if (last_token->id == TokenIdSemicolon) {
+            *token_index += 1;
         } else {
-            AstNode *statement_node = ast_parse_statement(pc, token_index);
-            node->data.block.statements.append(statement_node);
+            ast_invalid_token_error(pc, last_token);
         }
     }
     zig_unreachable();
src/parser.hpp
@@ -38,6 +38,7 @@ enum NodeType {
     NodeTypePrefixOpExpr,
     NodeTypeFnCallExpr,
     NodeTypeUse,
+    NodeTypeVoid,
 };
 
 struct AstNodeRoot {
src/semantic_info.hpp
@@ -106,7 +106,7 @@ struct TypeNode {
 };
 
 struct FnDefNode {
-    bool add_implicit_return;
+    TypeTableEntry *implicit_return_type;
     bool skip;
     LLVMValueRef *params;
 };
src/tokenizer.cpp
@@ -181,6 +181,8 @@ static void end_token(Tokenize *t) {
         t->cur_tok->id = TokenIdKeywordAs;
     } else if (mem_eql_str(token_mem, token_len, "use")) {
         t->cur_tok->id = TokenIdKeywordUse;
+    } else if (mem_eql_str(token_mem, token_len, "void")) {
+        t->cur_tok->id = TokenIdKeywordVoid;
     }
 
     t->cur_tok = nullptr;
@@ -574,6 +576,7 @@ static const char * token_name(Token *token) {
         case TokenIdKeywordExport: return "Export";
         case TokenIdKeywordAs: return "As";
         case TokenIdKeywordUse: return "Use";
+        case TokenIdKeywordVoid: return "Void";
         case TokenIdLParen: return "LParen";
         case TokenIdRParen: return "RParen";
         case TokenIdComma: return "Comma";
src/tokenizer.hpp
@@ -23,6 +23,7 @@ enum TokenId {
     TokenIdKeywordExport,
     TokenIdKeywordAs,
     TokenIdKeywordUse,
+    TokenIdKeywordVoid,
     TokenIdLParen,
     TokenIdRParen,
     TokenIdComma,
test/run_tests.cpp
@@ -209,11 +209,11 @@ fn a() {}
 
     add_compile_fail_case("unreachable with return", R"SOURCE(
 fn a() -> unreachable {return;}
-    )SOURCE", 1, ".tmp_source.zig:2:24: error: return statement in function with unreachable return type");
+    )SOURCE", 1, ".tmp_source.zig:2:24: error: type mismatch. expected unreachable. got void");
 
     add_compile_fail_case("control reaches end of non-void function", R"SOURCE(
 fn a() -> i32 {}
-    )SOURCE", 1, ".tmp_source.zig:2:1: error: control reaches end of non-void function");
+    )SOURCE", 1, ".tmp_source.zig:2:15: error: type mismatch. expected i32. got void");
 
     add_compile_fail_case("undefined function call", R"SOURCE(
 fn a() {
README.md
@@ -104,11 +104,7 @@ Type : token(Symbol) | PointerType | token(Unreachable)
 
 PointerType : token(Star) token(Const) Type | token(Star) token(Mut) Type
 
-Block : token(LBrace) many(Statement) token(RBrace)
-
-Statement : ExpressionStatement
-
-ExpressionStatement : Expression token(Semicolon)
+Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace)
 
 Expression : BoolOrExpression | ReturnExpression