Commit 9ca9a2c554

Andrew Kelley <superjoe30@gmail.com>
2015-11-27 18:52:31
allow empty function and return with no expression
1 parent 8219073
src/codegen.cpp
@@ -80,9 +80,14 @@ struct TypeNode {
     TypeTableEntry *entry;
 };
 
+struct FnDefNode {
+    bool add_implicit_return; // set by check_fn_def_control_flow for a void function with no top-level `return`; gen_block then emits LLVMBuildRetVoid after the body
+};
+
 struct CodeGenNode { // per-AstNode data stored on AstNode::codegen_node; which union member is valid depends on the node type
     union {
         TypeNode type_node; // for NodeTypeType
+        FnDefNode fn_def_node; // for NodeTypeFnDef
     } data;
 };
 
@@ -275,6 +280,60 @@ static void find_declarations(CodeGen *g, AstNode *node) {
     }
 }
 
+// Validates the top-level control flow of a function definition and records
+// whether codegen has to append an implicit `return` to the body.
+//
+// Only the direct statements of the function body block are inspected.
+// An error is reported through add_node_error when:
+//   * a `return` statement appears in a function whose return type is unreachable
+//     (checking stops at that point),
+//   * a `return` without an expression appears in a function whose return type
+//     is not void,
+//   * a non-return statement follows a `return` statement,
+//   * a function that is neither void nor unreachable contains no top-level `return`.
+// For a void function with no top-level `return`, add_implicit_return is set on
+// the CodeGenNode that this function allocates for `node`.
+//
+// Requires that the return type node has already been analyzed, i.e. that its
+// codegen_node carries a TypeTableEntry.
+static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeFnDef);
+    AstNode *proto_node = node->data.fn_def.fn_proto;
+    assert(proto_node->type == NodeTypeFnProto);
+    AstNode *return_type_node = proto_node->data.fn_proto.return_type;
+    assert(return_type_node->type == NodeTypeType);
+
+    node->codegen_node = allocate<CodeGenNode>(1);
+    FnDefNode *codegen_fn_def = &node->codegen_node->data.fn_def_node;
+    // Do not rely on the allocator handing back zeroed memory.
+    codegen_fn_def->add_implicit_return = false;
+
+    assert(return_type_node->codegen_node);
+    TypeTableEntry *type_entry = return_type_node->codegen_node->data.type_node.entry;
+    assert(type_entry);
+    TypeId type_id = type_entry->id;
+
+    AstNode *body_node = node->data.fn_def.body;
+    assert(body_node->type == NodeTypeBlock);
+
+    // TODO once we understand types, do this pass after type checking, and
+    // if an expression has an unreachable value then stop looking at statements after
+    // it. Then we can remove the check for `unreachable` at the end of this function.
+    bool prev_statement_return = false;
+    for (int i = 0; i < body_node->data.block.statements.length; i += 1) {
+        AstNode *statement_node = body_node->data.block.statements.at(i);
+        if (statement_node->type == NodeTypeStatementReturn) {
+            if (type_id == TypeIdUnreachable) {
+                add_node_error(g, statement_node,
+                        buf_sprintf("return statement in function with unreachable return type"));
+                return;
+            }
+            // `return;` is only valid in a void function: codegen lowers it to
+            // `ret void`, which is invalid IR in a function that returns a value.
+            if (type_id != TypeIdVoid && !statement_node->data.statement_return.expression) {
+                add_node_error(g, statement_node,
+                        buf_sprintf("return statement without a value in non-void function"));
+            }
+            prev_statement_return = true;
+        } else if (prev_statement_return) {
+            add_node_error(g, statement_node,
+                    buf_sprintf("unreachable code"));
+        }
+    }
+
+    if (!prev_statement_return) {
+        if (type_id == TypeIdVoid) {
+            // Falling off the end of a void function is fine; codegen adds the `ret void`.
+            codegen_fn_def->add_implicit_return = true;
+        } else if (type_id != TypeIdUnreachable) {
+            add_node_error(g, node,
+                    buf_sprintf("control reaches end of non-void function"));
+        }
+    }
+}
+
 static void analyze_node(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeRoot:
@@ -299,6 +358,8 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                 AstNode *proto_node = node->data.fn_def.fn_proto;
                 assert(proto_node->type == NodeTypeFnProto);
                 analyze_node(g, proto_node);
+
+                check_fn_def_control_flow(g, node);
                 break;
             }
         case NodeTypeFnDecl:
@@ -331,7 +392,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
             }
             break;
         case NodeTypeStatementReturn:
-            analyze_node(g, node->data.statement_return.expression);
+            if (node->data.statement_return.expression) {
+                analyze_node(g, node->data.statement_return.expression);
+            }
             break;
         case NodeTypeExpression:
             switch (node->data.expression.type) {
@@ -545,7 +608,7 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node) {
     zig_unreachable();
 }
 
-static void gen_block(CodeGen *g, AstNode *block_node) {
+static void gen_block(CodeGen *g, AstNode *block_node, bool add_implicit_return) {
     assert(block_node->type == NodeTypeBlock);
 
     llvm::DILexicalBlock *di_block = g->dbuilder->createLexicalBlock(g->block_scopes.last(),
@@ -558,10 +621,15 @@ static void gen_block(CodeGen *g, AstNode *block_node) {
             case NodeTypeStatementReturn:
                 {
                     AstNode *expr_node = statement_node->data.statement_return.expression;
-                    LLVMValueRef value = gen_expr(g, expr_node);
+                    if (expr_node) {
+                        LLVMValueRef value = gen_expr(g, expr_node);
 
-                    add_debug_source_node(g, statement_node);
-                    LLVMBuildRet(g->builder, value);
+                        add_debug_source_node(g, statement_node);
+                        LLVMBuildRet(g->builder, value);
+                    } else {
+                        add_debug_source_node(g, statement_node);
+                        LLVMBuildRetVoid(g->builder);
+                    }
                     break;
                 }
             case NodeTypeExpression:
@@ -583,6 +651,10 @@ static void gen_block(CodeGen *g, AstNode *block_node) {
         }
     }
 
+    if (add_implicit_return) {
+        LLVMBuildRetVoid(g->builder);
+    }
+
     g->block_scopes.pop();
 }
 
@@ -685,7 +757,10 @@ void code_gen(CodeGen *g) {
         LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
         LLVMPositionBuilderAtEnd(g->builder, entry_block);
 
-        gen_block(g, fn_def_node->data.fn_def.body);
+        CodeGenNode *codegen_node = fn_def_node->codegen_node;
+        assert(codegen_node);
+        bool add_implicit_return = codegen_node->data.fn_def_node.add_implicit_return;
+        gen_block(g, fn_def_node->data.fn_def.body, add_implicit_return);
 
         g->block_scopes.pop();
     }
src/parser.cpp
@@ -128,7 +128,8 @@ void ast_print(AstNode *node, int indent) {
             break;
         case NodeTypeStatementReturn:
             fprintf(stderr, "ReturnStatement\n");
-            ast_print(node->data.statement_return.expression, indent + 2);
+            if (node->data.statement_return.expression)
+                ast_print(node->data.statement_return.expression, indent + 2);
             break;
         case NodeTypeExternBlock:
             {
@@ -258,7 +259,7 @@ void ast_invalid_token_error(ParseContext *pc, Token *token) {
     ast_error(token, "invalid token: '%s'", buf_ptr(&token_value));
 }
 
-static AstNode *ast_parse_expression(ParseContext *pc, int token_index, int *new_token_index);
+static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory);
 
 
 static void ast_expect_token(ParseContext *pc, Token *token, TokenId token_id) {
@@ -374,7 +375,7 @@ static void ast_parse_fn_call_param_list(ParseContext *pc, int token_index, int
     }
 
     for (;;) {
-        AstNode *expr = ast_parse_expression(pc, token_index, &token_index);
+        AstNode *expr = ast_parse_expression(pc, &token_index, true);
         params->append(expr);
 
         Token *token = &pc->tokens->at(token_index);
@@ -411,28 +412,29 @@ static AstNode *ast_parse_fn_call(ParseContext *pc, int token_index, int *new_to
 /*
 Expression : token(Number) | token(String) | token(Unreachable) | FnCall
 */
-static AstNode *ast_parse_expression(ParseContext *pc, int token_index, int *new_token_index) {
-    Token *token = &pc->tokens->at(token_index);
+static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory) { // *token_index is in/out: advanced past the expression when one is parsed; mandatory selects error vs. nullptr when no expression starts here
+    Token *token = &pc->tokens->at(*token_index);
     AstNode *node = ast_create_node(NodeTypeExpression, token); // NOTE(review): created before the token is classified, so it goes unused when nullptr is returned below
     if (token->id == TokenIdKeywordUnreachable) {
         node->data.expression.type = AstNodeExpressionTypeUnreachable;
-        token_index += 1;
+        *token_index += 1;
     } else if (token->id == TokenIdSymbol) {
         node->data.expression.type = AstNodeExpressionTypeFnCall;
-        node->data.expression.data.fn_call = ast_parse_fn_call(pc, token_index, &token_index);
+        node->data.expression.data.fn_call = ast_parse_fn_call(pc, *token_index, token_index); // ast_parse_fn_call writes the index following the call back through token_index
     } else if (token->id == TokenIdNumberLiteral) {
         node->data.expression.type = AstNodeExpressionTypeNumber;
         ast_buf_from_token(pc, token, &node->data.expression.data.number);
-        token_index += 1;
+        *token_index += 1;
     } else if (token->id == TokenIdStringLiteral) {
         node->data.expression.type = AstNodeExpressionTypeString;
         parse_string_literal(pc, token, &node->data.expression.data.string);
-        token_index += 1;
-    } else {
+        *token_index += 1;
+    } else if (mandatory) {
         ast_invalid_token_error(pc, token); // NOTE(review): presumably does not return — confirm; otherwise node is returned with its expression type unset
+    } else {
+        return nullptr; // optional expression absent: *token_index is left unchanged
     }
 
-    *new_token_index = token_index;
     return node;
 }
 
@@ -441,14 +443,14 @@ Statement : ExpressionStatement  | ReturnStatement ;
 
 ExpressionStatement : Expression token(Semicolon) ;
 
-ReturnStatement : token(Return) Expression token(Semicolon) ;
+ReturnStatement : token(Return) option(Expression) token(Semicolon) ;
 */
 static AstNode *ast_parse_statement(ParseContext *pc, int token_index, int *new_token_index) {
     Token *token = &pc->tokens->at(token_index);
     if (token->id == TokenIdKeywordReturn) {
         AstNode *node = ast_create_node(NodeTypeStatementReturn, token);
         token_index += 1;
-        node->data.statement_return.expression = ast_parse_expression(pc, token_index, &token_index);
+        node->data.statement_return.expression = ast_parse_expression(pc, &token_index, false);
 
         Token *semicolon = &pc->tokens->at(token_index);
         token_index += 1;
@@ -460,7 +462,7 @@ static AstNode *ast_parse_statement(ParseContext *pc, int token_index, int *new_
                token->id == TokenIdKeywordUnreachable ||
                token->id == TokenIdNumberLiteral)
     {
-        AstNode *node = ast_parse_expression(pc, token_index, &token_index);
+        AstNode *node = ast_parse_expression(pc, &token_index, true);
 
         Token *semicolon = &pc->tokens->at(token_index);
         token_index += 1;
test/standalone.cpp
@@ -66,7 +66,12 @@ static void add_all_test_cases(void) {
             fn exit(code: i32) -> unreachable;
         }
 
+        fn empty_function_1() {}
+        fn empty_function_2() { return; }
+
         fn _start() -> unreachable {
+            empty_function_1();
+            empty_function_2();
             this_is_a_function();
         }
 
@@ -86,7 +91,7 @@ static void add_all_test_cases(void) {
         /**
          * multi line doc comment
          */
-        fn another_function() -> i32 { return 0; }
+        fn another_function() {}
 
         /// this is a documentation comment
         /// doc comment line 2
README.md
@@ -31,7 +31,9 @@ readable, safe, optimal, and concise code to solve any computing problem.
 
 ## Roadmap
 
- * empty function and return with no expression
+ * pub/private/export functions
+ * make sure that release mode optimizes out empty private functions
+ * test framework to test for compile errors
  * Simple .so library
  * Multiple files
  * figure out integers
@@ -87,7 +89,7 @@ Statement : ExpressionStatement | ReturnStatement
 
 ExpressionStatement : Expression token(Semicolon)
 
-ReturnStatement : token(Return) Expression token(Semicolon)
+ReturnStatement : token(Return) option(Expression) token(Semicolon)
 
 Expression : token(Number) | token(String) | token(Unreachable) | FnCall