Commit c17309dbc5

Andrew Kelley <superjoe30@gmail.com>
2016-01-20 04:29:09
add switch statement support to parser
1 parent 17e574f
doc/langref.md
@@ -94,7 +94,7 @@ BlockExpression : IfExpression | Block | WhileExpression | ForExpression | Switc
 
 SwitchExpression : "switch" "(" Expression ")" "{" many(SwitchProng) "}"
 
-SwitchProng : (list(SwitchItem, ",") | "else") option("(" "Symbol" ")") "=>" Expression ","
+SwitchProng : (list(SwitchItem, ",") | "else") option("," "(" "Symbol" ")") "=>" Expression ","
 
 SwitchItem : Expression | (Expression "..." Expression)
 
src/all_types.hpp
@@ -139,6 +139,9 @@ enum NodeType {
     NodeTypeIfVarExpr,
     NodeTypeWhileExpr,
     NodeTypeForExpr,
+    NodeTypeSwitchExpr,
+    NodeTypeSwitchProng,
+    NodeTypeSwitchRange,
     NodeTypeLabel,
     NodeTypeGoto,
     NodeTypeBreak,
@@ -411,6 +414,25 @@ struct AstNodeForExpr {
     VariableTableEntry *index_var;
 };
 
+struct AstNodeSwitchExpr {
+    AstNode *expr;
+    ZigList<AstNode *> prongs;
+
+    // populated by semantic analyzer
+    Expr resolved_expr;
+};
+
+struct AstNodeSwitchProng {
+    ZigList<AstNode *> items;
+    AstNode *var_symbol;
+    AstNode *expr;
+};
+
+struct AstNodeSwitchRange {
+    AstNode *start;
+    AstNode *end;
+};
+
 struct AstNodeLabel {
     Buf name;
 
@@ -623,6 +645,9 @@ struct AstNode {
         AstNodeIfVarExpr if_var_expr;
         AstNodeWhileExpr while_expr;
         AstNodeForExpr for_expr;
+        AstNodeSwitchExpr switch_expr;
+        AstNodeSwitchProng switch_prong;
+        AstNodeSwitchRange switch_range;
         AstNodeLabel label;
         AstNodeGoto goto_expr;
         AstNodeAsmExpr asm_expr;
src/analyze.cpp
@@ -30,6 +30,8 @@ static AstNode *first_executing_node(AstNode *node) {
             return first_executing_node(node->data.slice_expr.array_ref_expr);
         case NodeTypeFieldAccessExpr:
             return first_executing_node(node->data.field_access_expr.struct_expr);
+        case NodeTypeSwitchRange:
+            return first_executing_node(node->data.switch_range.start);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -61,6 +63,8 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeStructValueField:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
         case NodeTypeContainerInitExpr:
         case NodeTypeArrayType:
             return node;
@@ -943,6 +947,9 @@ static void resolve_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeIfVarExpr:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeLabel:
         case NodeTypeGoto:
         case NodeTypeBreak:
@@ -3007,6 +3014,12 @@ static TypeTableEntry *analyze_prefix_op_expr(CodeGen *g, ImportTableEntry *impo
     zig_unreachable();
 }
 
+static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    zig_panic("TODO analyze_switch_expr");
+}
+
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -3184,6 +3197,11 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeArrayType:
             return_type = analyze_array_type(g, import, context, expected_type, node);
             break;
+        case NodeTypeSwitchExpr:
+            return_type = analyze_switch_expr(g, import, context, expected_type, node);
+            break;
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeDirective:
         case NodeTypeFnDecl:
         case NodeTypeFnProto:
@@ -3338,6 +3356,9 @@ static void analyze_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeIfVarExpr:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeLabel:
         case NodeTypeGoto:
         case NodeTypeBreak:
@@ -3472,6 +3493,24 @@ static void collect_expr_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode
             }
             collect_expr_decl_deps(g, import, node->data.array_type.child_type, decl_node);
             break;
+        case NodeTypeSwitchExpr:
+            collect_expr_decl_deps(g, import, node->data.switch_expr.expr, decl_node);
+            for (int i = 0; i < node->data.switch_expr.prongs.length; i += 1) {
+                AstNode *prong = node->data.switch_expr.prongs.at(i);
+                collect_expr_decl_deps(g, import, prong, decl_node);
+            }
+            break;
+        case NodeTypeSwitchProng:
+            for (int i = 0; i < node->data.switch_prong.items.length; i += 1) {
+                AstNode *child = node->data.switch_prong.items.at(i);
+                collect_expr_decl_deps(g, import, child, decl_node);
+            }
+            collect_expr_decl_deps(g, import, node->data.switch_prong.expr, decl_node);
+            break;
+        case NodeTypeSwitchRange:
+            collect_expr_decl_deps(g, import, node->data.switch_range.start, decl_node);
+            collect_expr_decl_deps(g, import, node->data.switch_range.end, decl_node);
+            break;
         case NodeTypeVariableDeclaration:
         case NodeTypeFnProto:
         case NodeTypeExternBlock:
@@ -3661,6 +3700,9 @@ static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, Ast
         case NodeTypeIfVarExpr:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeLabel:
         case NodeTypeGoto:
         case NodeTypeBreak:
@@ -3869,6 +3911,10 @@ Expr *get_resolved_expr(AstNode *node) {
             return &node->data.label.resolved_expr;
         case NodeTypeArrayType:
             return &node->data.array_type.resolved_expr;
+        case NodeTypeSwitchExpr:
+            return &node->data.switch_expr.resolved_expr;
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -3902,6 +3948,9 @@ NumLitCodeGen *get_resolved_num_lit(AstNode *node) {
         case NodeTypeIfVarExpr:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeAsmExpr:
         case NodeTypeContainerInitExpr:
         case NodeTypeRoot:
@@ -3953,6 +4002,9 @@ TopLevelDecl *get_resolved_top_level_decl(AstNode *node) {
         case NodeTypeIfVarExpr:
         case NodeTypeWhileExpr:
         case NodeTypeForExpr:
+        case NodeTypeSwitchExpr:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
         case NodeTypeAsmExpr:
         case NodeTypeContainerInitExpr:
         case NodeTypeRoot:
src/codegen.cpp
@@ -1965,6 +1965,12 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
     return fn_entry->fn_value;
 }
 
+static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeSwitchExpr);
+
+    zig_panic("TODO gen_switch_expr");
+}
+
 static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeBinOpExpr:
@@ -2040,6 +2046,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
             }
         case NodeTypeContainerInitExpr:
             return gen_container_init_expr(g, node);
+        case NodeTypeSwitchExpr:
+            return gen_switch_expr(g, node);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -2053,6 +2061,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeStructField:
         case NodeTypeStructValueField:
         case NodeTypeArrayType:
+        case NodeTypeSwitchProng:
+        case NodeTypeSwitchRange:
             zig_unreachable();
     }
     zig_unreachable();
src/parser.cpp
@@ -123,6 +123,12 @@ const char *node_type_str(NodeType node_type) {
             return "WhileExpr";
         case NodeTypeForExpr:
             return "ForExpr";
+        case NodeTypeSwitchExpr:
+            return "SwitchExpr";
+        case NodeTypeSwitchProng:
+            return "SwitchProng";
+        case NodeTypeSwitchRange:
+            return "SwitchRange";
         case NodeTypeLabel:
             return "Label";
         case NodeTypeGoto:
@@ -342,6 +348,30 @@ void ast_print(AstNode *node, int indent) {
             }
             ast_print(node->data.for_expr.body, indent + 2);
             break;
+        case NodeTypeSwitchExpr:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.switch_expr.expr, indent + 2);
+            for (int i = 0; i < node->data.switch_expr.prongs.length; i += 1) {
+                AstNode *child_node = node->data.switch_expr.prongs.at(i);
+                ast_print(child_node, indent + 2);
+            }
+            break;
+        case NodeTypeSwitchProng:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            for (int i = 0; i < node->data.switch_prong.items.length; i += 1) {
+                AstNode *child_node = node->data.switch_prong.items.at(i);
+                ast_print(child_node, indent + 2);
+            }
+            if (node->data.switch_prong.var_symbol) {
+                ast_print(node->data.switch_prong.var_symbol, indent + 2);
+            }
+            ast_print(node->data.switch_prong.expr, indent + 2);
+            break;
+        case NodeTypeSwitchRange:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.switch_range.start, indent + 2);
+            ast_print(node->data.switch_range.end, indent + 2);
+            break;
         case NodeTypeLabel:
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.label.name));
             break;
@@ -2167,7 +2197,80 @@ static AstNode *ast_parse_for_expr(ParseContext *pc, int *token_index, bool mand
 }
 
 /*
-BlockExpression : IfExpression | Block | WhileExpression | ForExpression
+SwitchExpression : "switch" "(" Expression ")" "{" many(SwitchProng) "}"
+SwitchProng : (list(SwitchItem, ",") | "else") option("," "(" "Symbol" ")") "=>" Expression ","
+SwitchItem : Expression | (Expression "..." Expression)
+*/
+static AstNode *ast_parse_switch_expr(ParseContext *pc, int *token_index, bool mandatory) {
+    Token *token = &pc->tokens->at(*token_index);
+
+    if (token->id != TokenIdKeywordSwitch) {
+        if (mandatory) {
+            ast_invalid_token_error(pc, token);
+        } else {
+            return nullptr;
+        }
+    }
+    *token_index += 1;
+
+    AstNode *node = ast_create_node(pc, NodeTypeSwitchExpr, token);
+
+    ast_eat_token(pc, token_index, TokenIdLParen);
+    node->data.switch_expr.expr = ast_parse_expression(pc, token_index, true);
+    ast_eat_token(pc, token_index, TokenIdRParen);
+    ast_eat_token(pc, token_index, TokenIdLBrace);
+
+    for (;;) {
+        Token *token = &pc->tokens->at(*token_index);
+
+        if (token->id == TokenIdRBrace) {
+            *token_index += 1;
+            return node;
+        }
+
+        AstNode *prong_node = ast_create_node(pc, NodeTypeSwitchProng, token);
+        node->data.switch_expr.prongs.append(prong_node);
+
+        if (token->id == TokenIdKeywordElse) {
+            *token_index += 1;
+        } else for (;;) {
+            AstNode *expr1 = ast_parse_expression(pc, token_index, true);
+            Token *ellipsis_tok = &pc->tokens->at(*token_index);
+            if (ellipsis_tok->id == TokenIdEllipsis) {
+                *token_index += 1;
+
+                AstNode *range_node = ast_create_node(pc, NodeTypeSwitchRange, ellipsis_tok);
+                prong_node->data.switch_prong.items.append(range_node);
+
+                range_node->data.switch_range.start = expr1;
+                range_node->data.switch_range.end = ast_parse_expression(pc, token_index, true);
+            } else {
+                prong_node->data.switch_prong.items.append(expr1);
+            }
+            Token *comma_tok = &pc->tokens->at(*token_index);
+            if (comma_tok->id == TokenIdComma) {
+                *token_index += 1;
+                continue;
+            }
+            break;
+        }
+
+        Token *arrow_or_comma = &pc->tokens->at(*token_index);
+        if (arrow_or_comma->id == TokenIdComma) {
+            *token_index += 1;
+            ast_eat_token(pc, token_index, TokenIdLParen);
+            prong_node->data.switch_prong.var_symbol = ast_parse_symbol(pc, token_index);
+            ast_eat_token(pc, token_index, TokenIdRParen);
+        }
+
+        ast_eat_token(pc, token_index, TokenIdFatArrow);
+        prong_node->data.switch_prong.expr = ast_parse_expression(pc, token_index, true);
+        ast_eat_token(pc, token_index, TokenIdComma);
+    }
+}
+
+/*
+BlockExpression : IfExpression | Block | WhileExpression | ForExpression | SwitchExpression
 */
 static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -2176,10 +2279,6 @@ static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool ma
     if (if_expr)
         return if_expr;
 
-    AstNode *block = ast_parse_block(pc, token_index, false);
-    if (block)
-        return block;
-
     AstNode *while_expr = ast_parse_while_expr(pc, token_index, false);
     if (while_expr)
         return while_expr;
@@ -2188,6 +2287,14 @@ static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool ma
     if (for_expr)
         return for_expr;
 
+    AstNode *switch_expr = ast_parse_switch_expr(pc, token_index, false);
+    if (switch_expr)
+        return switch_expr;
+
+    AstNode *block = ast_parse_block(pc, token_index, false);
+    if (block)
+        return block;
+
     if (mandatory)
         ast_invalid_token_error(pc, token);
 
src/tokenizer.cpp
@@ -243,6 +243,8 @@ static void end_token(Tokenize *t) {
         t->cur_tok->id = TokenIdKeywordNull;
     } else if (mem_eql_str(token_mem, token_len, "noalias")) {
         t->cur_tok->id = TokenIdKeywordNoAlias;
+    } else if (mem_eql_str(token_mem, token_len, "switch")) {
+        t->cur_tok->id = TokenIdKeywordSwitch;
     }
 
     t->cur_tok = nullptr;
@@ -1035,6 +1037,7 @@ const char * token_name(TokenId id) {
         case TokenIdKeywordBreak: return "break";
         case TokenIdKeywordNull: return "null";
         case TokenIdKeywordNoAlias: return "noalias";
+        case TokenIdKeywordSwitch: return "switch";
         case TokenIdLParen: return "(";
         case TokenIdRParen: return ")";
         case TokenIdComma: return ",";
src/tokenizer.hpp
@@ -36,6 +36,7 @@ enum TokenId {
     TokenIdKeywordBreak,
     TokenIdKeywordNull,
     TokenIdKeywordNoAlias,
+    TokenIdKeywordSwitch,
     TokenIdLParen,
     TokenIdRParen,
     TokenIdComma,