Commit c1d77f2a23

Andrew Kelley <superjoe30@gmail.com>
2015-11-29 22:39:11
function call names are expressions
1 parent 918e764
src/codegen.cpp
@@ -306,7 +306,7 @@ static void find_declarations(CodeGen *g, AstNode *node) {
         case NodeTypeRoot:
         case NodeTypeBlock:
         case NodeTypeBoolOrExpr:
-        case NodeTypeFnCall:
+        case NodeTypeFnCallExpr:
         case NodeTypeRootExportDecl:
         case NodeTypeBoolAndExpr:
         case NodeTypeComparisonExpr:
@@ -378,6 +378,14 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
     }
 }
 
+static Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
+    // Assume that the expression evaluates to a simple name and return the buf
+    // TODO after type checking works we should be able to remove this hack
+    assert(node->type == NodeTypePrimaryExpr);
+    assert(node->data.primary_expr.type == PrimaryExprTypeSymbol);
+    return &node->data.primary_expr.data.symbol;
+}
+
 static void analyze_node(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeRoot:
@@ -487,9 +495,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
             if (node->data.bool_or_expr.op2)
                 analyze_node(g, node->data.bool_or_expr.op2);
             break;
-        case NodeTypeFnCall:
+        case NodeTypeFnCallExpr:
             {
-                Buf *name = &node->data.fn_call.name;
+                Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
 
                 auto entry = g->fn_table.maybe_get(name);
                 if (!entry) {
@@ -499,7 +507,7 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                     FnTableEntry *fn_table_entry = entry->value;
                     assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
                     int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
-                    int actual_param_count = node->data.fn_call.params.length;
+                    int actual_param_count = node->data.fn_call_expr.params.length;
                     if (expected_param_count != actual_param_count) {
                         add_node_error(g, node,
                                 buf_sprintf("wrong number of arguments. Expected %d, got %d.",
@@ -507,8 +515,8 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                     }
                 }
 
-                for (int i = 0; i < node->data.fn_call.params.length; i += 1) {
-                    AstNode *child = node->data.fn_call.params.at(i);
+                for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+                    AstNode *child = node->data.fn_call_expr.params.at(i);
                     analyze_node(g, child);
                 }
                 break;
@@ -551,11 +559,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                 case PrimaryExprTypeNumber:
                 case PrimaryExprTypeString:
                 case PrimaryExprTypeUnreachable:
+                case PrimaryExprTypeSymbol:
                     // nothing to do
                     break;
-                case PrimaryExprTypeFnCall:
-                    analyze_node(g, node->data.primary_expr.data.fn_call);
-                    break;
                 case PrimaryExprTypeGroupedExpr:
                     analyze_node(g, node->data.primary_expr.data.grouped_expr);
                     break;
@@ -662,33 +668,6 @@ static void add_debug_source_node(CodeGen *g, AstNode *node) {
                 g->block_scopes.last()));
 }
 
-static LLVMValueRef gen_fn_call(CodeGen *g, AstNode *fn_call_node) {
-    assert(fn_call_node->type == NodeTypeFnCall);
-
-    Buf *name = &fn_call_node->data.fn_call.name;
-    FnTableEntry *fn_table_entry = g->fn_table.get(name);
-    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
-    int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
-    int actual_param_count = fn_call_node->data.fn_call.params.length;
-    assert(expected_param_count == actual_param_count);
-
-    LLVMValueRef *param_values = allocate<LLVMValueRef>(actual_param_count);
-    for (int i = 0; i < actual_param_count; i += 1) {
-        AstNode *expr_node = fn_call_node->data.fn_call.params.at(i);
-        param_values[i] = gen_expr(g, expr_node);
-    }
-
-    add_debug_source_node(g, fn_call_node);
-    LLVMValueRef result = LLVMZigBuildCall(g->builder, fn_table_entry->fn_value,
-            param_values, actual_param_count, fn_table_entry->calling_convention, "");
-
-    if (type_is_unreachable(fn_table_entry->proto_node->data.fn_proto.return_type)) {
-        return LLVMBuildUnreachable(g->builder);
-    } else {
-        return result;
-    }
-}
-
 static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) {
     auto entry = g->str_table.maybe_get(str);
     if (entry) {
@@ -733,17 +712,47 @@ static LLVMValueRef gen_primary_expr(CodeGen *g, AstNode *node) {
         case PrimaryExprTypeUnreachable:
             add_debug_source_node(g, node);
             return LLVMBuildUnreachable(g->builder);
-        case PrimaryExprTypeFnCall:
-            return gen_fn_call(g, prim_expr->data.fn_call);
         case PrimaryExprTypeGroupedExpr:
             return gen_expr(g, prim_expr->data.grouped_expr);
         case PrimaryExprTypeBlock:
+            zig_panic("TODO block in expression");
+            break;
+        case PrimaryExprTypeSymbol:
+            zig_panic("TODO variable reference");
             break;
     }
 
     zig_unreachable();
 }
 
+static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeFnCallExpr);
+
+    Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
+
+    FnTableEntry *fn_table_entry = g->fn_table.get(name);
+    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
+    int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
+    int actual_param_count = node->data.fn_call_expr.params.length;
+    assert(expected_param_count == actual_param_count);
+
+    LLVMValueRef *param_values = allocate<LLVMValueRef>(actual_param_count);
+    for (int i = 0; i < actual_param_count; i += 1) {
+        AstNode *expr_node = node->data.fn_call_expr.params.at(i);
+        param_values[i] = gen_expr(g, expr_node);
+    }
+
+    add_debug_source_node(g, node);
+    LLVMValueRef result = LLVMZigBuildCall(g->builder, fn_table_entry->fn_value,
+            param_values, actual_param_count, fn_table_entry->calling_convention, "");
+
+    if (type_is_unreachable(fn_table_entry->proto_node->data.fn_proto.return_type)) {
+        return LLVMBuildUnreachable(g->builder);
+    } else {
+        return result;
+    }
+}
+
 static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypePrefixOpExpr);
     assert(node->data.prefix_op_expr.primary_expr);
@@ -1028,6 +1037,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
             return gen_return_expr(g, node);
         case NodeTypePrefixOpExpr:
             return gen_prefix_op_expr(g, node);
+        case NodeTypeFnCallExpr:
+            return gen_fn_call_expr(g, node);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -1036,7 +1047,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeParamDecl:
         case NodeTypeType:
         case NodeTypeBlock:
-        case NodeTypeFnCall:
         case NodeTypeExternBlock:
         case NodeTypeDirective:
         case NodeTypeBoolAndExpr:
src/parser.cpp
@@ -96,8 +96,8 @@ const char *node_type_str(NodeType node_type) {
             return "Block";
         case NodeTypeBoolOrExpr:
             return "BoolOrExpr";
-        case NodeTypeFnCall:
-            return "FnCall";
+        case NodeTypeFnCallExpr:
+            return "FnCallExpr";
         case NodeTypeExternBlock:
             return "ExternBlock";
         case NodeTypeDirective:
@@ -232,10 +232,11 @@ void ast_print(AstNode *node, int indent) {
             if (node->data.bool_or_expr.op2)
                 ast_print(node->data.bool_or_expr.op2, indent + 2);
             break;
-        case NodeTypeFnCall:
-            fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.fn_call.name));
-            for (int i = 0; i < node->data.fn_call.params.length; i += 1) {
-                AstNode *child = node->data.fn_call.params.at(i);
+        case NodeTypeFnCallExpr:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.fn_call_expr.fn_ref_expr, indent + 2);
+            for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+                AstNode *child = node->data.fn_call_expr.params.at(i);
                 ast_print(child, indent + 2);
             }
             break;
@@ -318,10 +319,6 @@ void ast_print(AstNode *node, int indent) {
                 case PrimaryExprTypeUnreachable:
                     fprintf(stderr, "PrimaryExpr Unreachable\n");
                     break;
-                case PrimaryExprTypeFnCall:
-                    fprintf(stderr, "PrimaryExpr FnCall\n");
-                    ast_print(node->data.primary_expr.data.fn_call, indent + 2);
-                    break;
                 case PrimaryExprTypeGroupedExpr:
                     fprintf(stderr, "PrimaryExpr GroupedExpr\n");
                     ast_print(node->data.primary_expr.data.grouped_expr, indent + 2);
@@ -330,6 +327,10 @@ void ast_print(AstNode *node, int indent) {
                     fprintf(stderr, "PrimaryExpr Block\n");
                     ast_print(node->data.primary_expr.data.block, indent + 2);
                     break;
+                case PrimaryExprTypeSymbol:
+                    fprintf(stderr, "PrimaryExpr Symbol %s\n",
+                            buf_ptr(&node->data.primary_expr.data.symbol));
+                    break;
             }
             break;
         case NodeTypeGroupedExpr:
@@ -626,32 +627,7 @@ static AstNode *ast_parse_grouped_expr(ParseContext *pc, int *token_index, bool
 }
 
 /*
-FnCall : token(Symbol) token(LParen) list(Expression, token(Comma)) token(RParen) ;
-*/
-static AstNode *ast_parse_fn_call(ParseContext *pc, int *token_index, bool mandatory) {
-    Token *fn_name = &pc->tokens->at(*token_index);
-    if (fn_name->id != TokenIdSymbol) {
-        if (mandatory) {
-            ast_invalid_token_error(pc, fn_name);
-        } else {
-            return nullptr;
-        }
-    }
-
-    *token_index += 1;
-
-    AstNode *node = ast_create_node(NodeTypeFnCall, fn_name);
-
-
-    ast_buf_from_token(pc, fn_name, &node->data.fn_call.name);
-
-    ast_parse_fn_call_param_list(pc, *token_index, token_index, &node->data.fn_call.params);
-
-    return node;
-}
-
-/*
-PrimaryExpression : token(Number) | token(String) | token(Unreachable) | FnCall | GroupedExpression | Block
+PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | Block | token(Symbol)
 */
 static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -673,6 +649,12 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         node->data.primary_expr.type = PrimaryExprTypeUnreachable;
         *token_index += 1;
         return node;
+    } else if (token->id == TokenIdSymbol) {
+        AstNode *node = ast_create_node(NodeTypePrimaryExpr, token);
+        node->data.primary_expr.type = PrimaryExprTypeSymbol;
+        ast_buf_from_token(pc, token, &node->data.primary_expr.data.symbol);
+        *token_index += 1;
+        return node;
     }
 
     AstNode *block_node = ast_parse_block(pc, token_index, false);
@@ -691,20 +673,31 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         return node;
     }
 
-    AstNode *fn_call_node = ast_parse_fn_call(pc, token_index, false);
-    if (fn_call_node) {
-        AstNode *node = ast_create_node(NodeTypePrimaryExpr, token);
-        node->data.primary_expr.type = PrimaryExprTypeFnCall;
-        node->data.primary_expr.data.fn_call = fn_call_node;
-        return node;
-    }
-
     if (!mandatory)
         return nullptr;
 
     ast_invalid_token_error(pc, token);
 }
 
+/*
+FnCallExpression : PrimaryExpression token(LParen) list(Expression, token(Comma)) token(RParen) | PrimaryExpression
+*/
+static AstNode *ast_parse_fn_call_expr(ParseContext *pc, int *token_index, bool mandatory) {
+    AstNode *primary_expr = ast_parse_primary_expr(pc, token_index, mandatory);
+    if (!primary_expr)
+        return nullptr;
+
+    Token *l_paren = &pc->tokens->at(*token_index);
+    if (l_paren->id != TokenIdLParen)
+        return primary_expr;
+
+    AstNode *node = ast_create_node_with_node(NodeTypeFnCallExpr, primary_expr);
+    node->data.fn_call_expr.fn_ref_expr = primary_expr;
+    ast_parse_fn_call_param_list(pc, *token_index, token_index, &node->data.fn_call_expr.params);
+
+    return node;
+}
+
 static PrefixOp tok_to_prefix_op(Token *token) {
     switch (token->id) {
         case TokenIdBang: return PrefixOpBoolNot;
@@ -732,15 +725,15 @@ static PrefixOp ast_parse_prefix_op(ParseContext *pc, int *token_index, bool man
 }
 
 /*
-PrefixOpExpression : PrefixOp PrimaryExpression | PrimaryExpression
+PrefixOpExpression : PrefixOp FnCallExpression | FnCallExpression
 */
 static AstNode *ast_parse_prefix_op_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
     PrefixOp prefix_op = ast_parse_prefix_op(pc, token_index, false);
     if (prefix_op == PrefixOpInvalid)
-        return ast_parse_primary_expr(pc, token_index, mandatory);
+        return ast_parse_fn_call_expr(pc, token_index, mandatory);
 
-    AstNode *primary_expr = ast_parse_primary_expr(pc, token_index, true);
+    AstNode *primary_expr = ast_parse_fn_call_expr(pc, token_index, true);
     AstNode *node = ast_create_node(NodeTypePrefixOpExpr, token);
     node->data.prefix_op_expr.primary_expr = primary_expr;
     node->data.prefix_op_expr.prefix_op = prefix_op;
src/parser.hpp
@@ -24,7 +24,6 @@ enum NodeType {
     NodeTypeParamDecl,
     NodeTypeType,
     NodeTypeBlock,
-    NodeTypeFnCall,
     NodeTypeExternBlock,
     NodeTypeDirective,
     NodeTypeReturnExpr,
@@ -41,6 +40,7 @@ enum NodeType {
     NodeTypePrimaryExpr,
     NodeTypeGroupedExpr,
     NodeTypePrefixOpExpr,
+    NodeTypeFnCallExpr,
 };
 
 struct AstNodeRoot {
@@ -103,8 +103,8 @@ struct AstNodeBoolOrExpr {
     AstNode *op2;
 };
 
-struct AstNodeFnCall {
-    Buf name;
+struct AstNodeFnCallExpr {
+    AstNode *fn_ref_expr;
     ZigList<AstNode *> params;
 };
 
@@ -214,9 +214,9 @@ enum PrimaryExprType {
     PrimaryExprTypeNumber,
     PrimaryExprTypeString,
     PrimaryExprTypeUnreachable,
-    PrimaryExprTypeFnCall,
     PrimaryExprTypeGroupedExpr,
     PrimaryExprTypeBlock,
+    PrimaryExprTypeSymbol,
 };
 
 struct AstNodePrimaryExpr {
@@ -224,7 +224,7 @@ struct AstNodePrimaryExpr {
     union {
         Buf number;
         Buf string;
-        AstNode *fn_call;
+        Buf symbol;
         AstNode *grouped_expr;
         AstNode *block;
     } data;
@@ -263,7 +263,6 @@ struct AstNode {
         AstNodeBlock block;
         AstNodeReturnExpr return_expr;
         AstNodeBoolOrExpr bool_or_expr;
-        AstNodeFnCall fn_call;
         AstNodeExternBlock extern_block;
         AstNodeDirective directive;
         AstNodeBoolAndExpr bool_and_expr;
@@ -278,6 +277,7 @@ struct AstNode {
         AstNodePrimaryExpr primary_expr;
         AstNodeGroupedExpr grouped_expr;
         AstNodePrefixOpExpr prefix_op_expr;
+        AstNodeFnCallExpr fn_call_expr;
     } data;
 };
 
README.md
@@ -88,6 +88,8 @@ ExternBlock : many(Directive) token(Extern) token(LBrace) many(FnDecl) token(RBr
 
 FnProto : many(Directive) option(FnVisibleMod) token(Fn) token(Symbol) ParamDeclList option(token(Arrow) Type)
 
+Directive : token(NumberSign) token(Symbol) token(LParen) token(String) token(RParen)
+
 FnVisibleMod : token(Pub) | token(Export)
 
 FnDecl : FnProto token(Semicolon)
@@ -142,15 +144,13 @@ CastExpression : PrefixOpExpression token(as) Type | PrefixOpExpression
 
 PrefixOpExpression : PrefixOp FnCallExpression | FnCallExpression
 
-FnCallExpression : PrimaryExpression token(LParen) list(Expression, token(Comma)) token(RParen)
+FnCallExpression : PrimaryExpression token(LParen) list(Expression, token(Comma)) token(RParen) | PrimaryExpression
 
 PrefixOp : token(Not) | token(Dash) | token(Tilde)
 
-PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | Block
+PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | Block | token(Symbol)
 
 GroupedExpression : token(LParen) Expression token(RParen)
-
-Directive : token(NumberSign) token(Symbol) token(LParen) token(String) token(RParen)
 ```
 
 ### Operator Precedence