Commit ad9759bc8e

Andrew Kelley <superjoe30@gmail.com>
2016-01-20 10:12:24
basic support for switch expression
1 parent 3eca42c
doc/langref.md
@@ -94,7 +94,7 @@ BlockExpression : IfExpression | Block | WhileExpression | ForExpression | Switc
 
 SwitchExpression : "switch" "(" Expression ")" "{" many(SwitchProng) "}"
 
-SwitchProng : (list(SwitchItem, ",") | "else") option("," "(" "Symbol" ")") "=>" Expression ","
+SwitchProng : (list(SwitchItem, ",") | "else") option(":" "(" "Symbol" ")") "=>" Expression ","
 
 SwitchItem : Expression | (Expression "..." Expression)
 
@@ -197,8 +197,8 @@ x{}
                 | Example  | Characters  | Escapes        | Null Term | Type
 ----------------|----------|-------------|----------------|-----------|----------
  Byte           | 'H'      | All ASCII   | Byte           | No        | u8
- UTF-8 Bytes    | "hello"  | All Unicode | Byte & Unicode | No        | [5; u8]
- UTF-8 C string | c"hello" | All Unicode | Byte & Unicode | Yes       | *const u8
+ UTF-8 Bytes    | "hello"  | All Unicode | Byte & Unicode | No        | [5]u8
+ UTF-8 C string | c"hello" | All Unicode | Byte & Unicode | Yes       | &const u8
 
 ### Byte Escapes
 
example/list/list.zig
@@ -68,37 +68,35 @@ pub fn free#(T: type)(ptr: ?&T) {
 
 ////////////////// alternate
 
-// previously proposed, but with : instead of ->
-// `:` means "parser should expect a type now"
-fn max#(T :type)(a :T, b :T) :T {
+// previously proposed but without ->
+fn max#(T: type)(a: T, b: T) T {
     if (a > b) a else b
 }
 
 // andy's new idea
-// parameters can talk about @typeof() for previous parameters.
-// using :T here is equivalent to @child_type(@typeof(T))
-fn max(T :type, a :T, b :T) :T {
+// parameters can reference other inline parameters.
+fn max(inline T: type, a: T, b: T) T {
     if (a > b) a else b
 }
 
 fn f() {
-    const x :i32 = 1234;
-    const y :i32 = 5678;
+    const x: i32 = 1234;
+    const y: i32 = 5678;
     const z = max(@typeof(x), x, y);
 }
 
 // So, type-generic functions don't need any fancy syntax. type-generic
 // containers still do, though:
 
-pub struct List(T :type) {
-    items :?&T,
-    length :isize,
-    capacity :isize,
+pub struct List(T: type) {
+    items: ?&T,
+    length: isize,
+    capacity: isize,
 }
 
-// Types are always marked with ':' so we don't need '#' to indicate type generic parameters.
+// we don't need '#' to indicate type generic parameters.
 
 fn f() {
-    var list :List(:u8);
+    var list: List(u8);
 }
 
src/all_types.hpp
@@ -26,6 +26,7 @@ struct BuiltinFnEntry;
 struct LabelTableEntry;
 struct TypeStructField;
 struct CodeGen;
+struct ConstExprValue;
 
 enum OutType {
     OutTypeUnknown,
@@ -57,6 +58,11 @@ struct Cast {
     AstNode *source_node;
 };
 
+struct ConstEnumValue {
+    uint64_t tag;
+    ConstExprValue *payload;
+};
+
 struct ConstExprValue {
     bool ok; // true if constant expression evalution worked
     bool depends_on_compile_var;
@@ -69,6 +75,7 @@ struct ConstExprValue {
         FnTableEntry *x_fn;
         TypeTableEntry *x_type;
         ConstExprValue *x_maybe;
+        ConstEnumValue x_enum;
     } data;
 };
 
@@ -426,6 +433,10 @@ struct AstNodeSwitchProng {
     ZigList<AstNode *> items;
     AstNode *var_symbol;
     AstNode *expr;
+
+    // populated by semantic analyzer
+    BlockContext *block_context;
+    VariableTableEntry *var;
 };
 
 struct AstNodeSwitchRange {
@@ -933,6 +944,7 @@ struct CodeGen {
 
     OutType out_type;
     FnTableEntry *cur_fn;
+    // TODO remove this in favor of get_resolved_expr(expr_node)->context
     BlockContext *cur_block_context;
     ZigList<LLVMBasicBlockRef> break_block_stack;
     ZigList<LLVMBasicBlockRef> continue_block_stack;
src/analyze.cpp
@@ -1350,6 +1350,11 @@ static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *imp
                     buf_ptr(&enum_type->name),
                     buf_ptr(field_name),
                     buf_ptr(&type_enum_field->type_entry->name)));
+        } else {
+            Expr *expr = get_resolved_expr(field_access_node);
+            expr->const_val.ok = true;
+            expr->const_val.data.x_enum.tag = type_enum_field->value;
+            expr->const_val.data.x_enum.payload = nullptr;
         }
     } else {
         add_node_error(g, field_access_node,
@@ -1945,6 +1950,25 @@ static TypeTableEntry *analyze_bool_bin_op_expr(CodeGen *g, ImportTableEntry *im
         }
     } else if (resolved_type->id == TypeTableEntryIdFloat) {
         answer = eval_bool_bin_op_float(op1_val->data.x_float, bin_op_type, op2_val->data.x_float);
+    } else if (resolved_type->id == TypeTableEntryIdEnum) {
+        ConstEnumValue *enum1 = &op1_val->data.x_enum;
+        ConstEnumValue *enum2 = &op2_val->data.x_enum;
+        bool are_equal = false;
+        if (enum1->tag == enum2->tag) {
+            TypeEnumField *enum_field = &op1_type->data.enumeration.fields[enum1->tag];
+            if (enum_field->type_entry->size_in_bits > 0) {
+                zig_panic("TODO const expr analyze enum special value for equality");
+            } else {
+                are_equal = true;
+            }
+        }
+        if (bin_op_type == BinOpTypeCmpEq) {
+            answer = are_equal;
+        } else if (bin_op_type == BinOpTypeCmpNotEq) {
+            answer = !are_equal;
+        } else {
+            zig_unreachable();
+        }
     } else {
         zig_unreachable();
     }
@@ -3017,7 +3041,62 @@ static TypeTableEntry *analyze_prefix_op_expr(CodeGen *g, ImportTableEntry *impo
 static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
-    zig_panic("TODO analyze_switch_expr");
+    AstNode *expr_node = node->data.switch_expr.expr;
+    TypeTableEntry *expr_type = analyze_expression(g, import, context, nullptr, expr_node);
+
+    if (expected_type == nullptr) {
+        zig_panic("TODO resolve peer compatibility of switch prongs");
+    }
+
+    if (expr_type->id == TypeTableEntryIdInvalid) {
+        return expr_type;
+    } else if (expr_type->id == TypeTableEntryIdUnreachable) {
+        add_node_error(g, first_executing_node(expr_node),
+                buf_sprintf("switch on unreachable expression not allowed"));
+        return g->builtin_types.entry_invalid;
+    } else {
+        AstNode *else_prong = nullptr;
+        for (int prong_i = 0; prong_i < node->data.switch_expr.prongs.length; prong_i += 1) {
+            AstNode *prong_node = node->data.switch_expr.prongs.at(prong_i);
+
+            TypeTableEntry *var_type;
+            if (prong_node->data.switch_prong.items.length == 0) {
+                if (else_prong) {
+                    add_node_error(g, prong_node, buf_sprintf("multiple else prongs in switch expression"));
+                } else {
+                    else_prong = prong_node;
+                }
+                var_type = expr_type;
+            } else {
+                for (int item_i = 0; item_i < prong_node->data.switch_prong.items.length; item_i += 1) {
+                    AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
+                    if (item_node->type == NodeTypeSwitchRange) {
+                        zig_panic("TODO range in switch statement");
+                    }
+                    analyze_expression(g, import, context, expr_type, item_node);
+                    ConstExprValue *const_val = &get_resolved_expr(item_node)->const_val;
+                    if (!const_val->ok) {
+                        add_node_error(g, item_node, buf_sprintf("unable to resolve constant expression"));
+                    }
+                }
+                var_type = expr_type;
+            }
+
+            BlockContext *child_context = new_block_context(node, context);
+            prong_node->data.switch_prong.block_context = child_context;
+            AstNode *var_node = prong_node->data.switch_prong.var_symbol;
+            if (var_node) {
+                assert(var_node->type == NodeTypeSymbol);
+                Buf *var_name = &var_node->data.symbol_expr.symbol;
+                prong_node->data.switch_prong.var = add_local_var(g, var_node, child_context, var_name,
+                        var_type, true);
+            }
+
+            analyze_expression(g, import, child_context, expected_type,
+                    prong_node->data.switch_prong.expr);
+        }
+    }
+    return expected_type;
 }
 
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
src/codegen.cpp
@@ -1968,7 +1968,69 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeSwitchExpr);
 
-    zig_panic("TODO gen_switch_expr");
+    LLVMValueRef target_value = gen_expr(g, node->data.switch_expr.expr);
+
+    bool end_unreachable = (get_expr_type(node)->id == TypeTableEntryIdUnreachable);
+
+    LLVMBasicBlockRef end_block = end_unreachable ?
+        nullptr : LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchEnd");
+    LLVMBasicBlockRef else_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchElse");
+    int prong_count = node->data.switch_expr.prongs.length;
+
+    add_debug_source_node(g, node);
+    LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, target_value, else_block, prong_count);
+
+    ZigList<LLVMValueRef> incoming_values = {0};
+    ZigList<LLVMBasicBlockRef> incoming_blocks = {0};
+
+    AstNode *else_prong = nullptr;
+    for (int prong_i = 0; prong_i < prong_count; prong_i += 1) {
+        AstNode *prong_node = node->data.switch_expr.prongs.at(prong_i);
+        LLVMBasicBlockRef prong_block;
+        if (prong_node->data.switch_prong.items.length == 0) {
+            assert(!else_prong);
+            else_prong = prong_node;
+            prong_block = else_block;
+        } else {
+            prong_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchProng");
+            for (int item_i = 0; item_i < prong_node->data.switch_prong.items.length; item_i += 1) {
+                AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
+                assert(item_node->type != NodeTypeSwitchRange);
+                assert(get_resolved_expr(item_node)->const_val.ok);
+                LLVMValueRef val = gen_expr(g, item_node);
+                LLVMAddCase(switch_instr, val, prong_block);
+            }
+        }
+        assert(!prong_node->data.switch_prong.var_symbol);
+        LLVMPositionBuilderAtEnd(g->builder, prong_block);
+        AstNode *prong_expr = prong_node->data.switch_prong.expr;
+        LLVMValueRef prong_val = gen_expr(g, prong_expr);
+
+        if (get_expr_type(prong_expr)->id != TypeTableEntryIdUnreachable) {
+            add_debug_source_node(g, prong_expr);
+            LLVMBuildBr(g->builder, end_block);
+            incoming_values.append(prong_val);
+            incoming_blocks.append(prong_block);
+        }
+    }
+
+    if (!else_prong) {
+        LLVMPositionBuilderAtEnd(g->builder, else_block);
+        add_debug_source_node(g, node);
+        LLVMBuildUnreachable(g->builder);
+    }
+
+    if (end_unreachable) {
+        return nullptr;
+    }
+
+    LLVMPositionBuilderAtEnd(g->builder, end_block);
+
+    add_debug_source_node(g, node);
+    LLVMValueRef phi = LLVMBuildPhi(g->builder, get_expr_type(node)->type_ref, "");
+    LLVMAddIncoming(phi, incoming_values.items, incoming_blocks.items, incoming_values.length);
+
+    return phi;
 }
 
 static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
src/parser.cpp
@@ -2255,8 +2255,8 @@ static AstNode *ast_parse_switch_expr(ParseContext *pc, int *token_index, bool m
             break;
         }
 
-        Token *arrow_or_comma = &pc->tokens->at(*token_index);
-        if (arrow_or_comma->id == TokenIdComma) {
+        Token *arrow_or_colon = &pc->tokens->at(*token_index);
+        if (arrow_or_colon->id == TokenIdColon) {
             *token_index += 1;
             ast_eat_token(pc, token_index, TokenIdLParen);
             prong_node->data.switch_prong.var_symbol = ast_parse_symbol(pc, token_index);
test/run_tests.cpp
@@ -1180,6 +1180,33 @@ fn fn2() u32 => {6}
 fn fn3() u32 => {7}
 fn fn4() u32 => {8}
     )SOURCE", "5\n6\n7\n8\n");
+
+    add_simple_case("switch statement", R"SOURCE(
+import "std.zig";
+
+enum Foo {
+    A,
+    B,
+    C,
+    D,
+}
+
+pub fn main(args: [][]u8) i32 => {
+    const foo = Foo.C;
+    const val: i32 = switch (foo) {
+        Foo.A => 1,
+        Foo.B => 2,
+        Foo.C => 3,
+        Foo.D => 4,
+    };
+    if (val != 3) {
+        print_str("BAD\n");
+    }
+
+    print_str("OK\n");
+    return 0;
+}
+    )SOURCE", "OK\n");
 }
 
 
@@ -1511,6 +1538,16 @@ fn f(Foo: i32) => {
 }
     )SOURCE", 2, ".tmp_source.zig:5:6: error: variable shadows type 'Foo'",
                  ".tmp_source.zig:6:5: error: variable shadows type 'Bar'");
+
+    add_compile_fail_case("multiple else prongs in a switch", R"SOURCE(
+fn f() => {
+    const value: bool = switch (u32(111)) {
+        1234 => false,
+        else => true,
+        else => true,
+    };
+}
+    )SOURCE", 1, ".tmp_source.zig:6:9: error: multiple else prongs in switch expression");
 }
 
 static void print_compiler_invocation(TestCase *test_case) {