Commit a4cba900e5

Andrew Kelley <superjoe30@gmail.com>
2016-02-05 00:09:06
no namespace required when switching on enum
See #43
1 parent 5490f90
src/all_types.hpp
@@ -673,6 +673,7 @@ struct AstNodeSymbolExpr {
     FnTableEntry *fn_entry;
     // set this to instead of analyzing the node, pretend it's a type entry and it's this one.
     TypeTableEntry *override_type_entry;
+    TypeEnumField *enum_field;
 };
 
 struct AstNodeBoolLiteral {
src/analyze.cpp
@@ -4508,10 +4508,30 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                     if (item_node->type == NodeTypeSwitchRange) {
                         zig_panic("TODO range in switch statement");
                     }
-                    analyze_expression(g, import, context, expr_type, item_node);
-                    ConstExprValue *const_val = &get_resolved_expr(item_node)->const_val;
-                    if (!const_val->ok) {
-                        add_node_error(g, item_node, buf_sprintf("unable to resolve constant expression"));
+
+                    if (expr_type->id == TypeTableEntryIdEnum) {
+                        if (item_node->type == NodeTypeSymbol) {
+                            Buf *field_name = &item_node->data.symbol_expr.symbol;
+                            TypeEnumField *type_enum_field = get_enum_field(expr_type, field_name);
+                            if (type_enum_field) {
+                                item_node->data.symbol_expr.enum_field = type_enum_field;
+                            } else {
+                                add_node_error(g, item_node,
+                                        buf_sprintf("enum '%s' has no field '%s'",
+                                            buf_ptr(&expr_type->name), buf_ptr(field_name)));
+                            }
+                        } else {
+                            add_node_error(g, item_node, buf_sprintf("expected enum tag name"));
+                        }
+                    } else {
+                        TypeTableEntry *item_type = analyze_expression(g, import, context, expr_type, item_node);
+                        if (item_type->id != TypeTableEntryIdInvalid) {
+                            ConstExprValue *const_val = &get_resolved_expr(item_node)->const_val;
+                            if (!const_val->ok) {
+                                add_node_error(g, item_node,
+                                    buf_sprintf("unable to resolve constant expression"));
+                            }
+                        }
                     }
                 }
                 var_type = expr_type;
src/codegen.cpp
@@ -2357,21 +2357,16 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
             for (int item_i = 0; item_i < prong_node->data.switch_prong.items.length; item_i += 1) {
                 AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
                 assert(item_node->type != NodeTypeSwitchRange);
-                assert(get_resolved_expr(item_node)->const_val.ok);
-                LLVMValueRef val_handle = gen_expr(g, item_node);
                 LLVMValueRef val;
-                if (handle_is_ptr(target_type)) {
-                    if (target_type->id == TypeTableEntryIdEnum) {
-                        ConstExprValue *item_const_val = &get_resolved_expr(item_node)->const_val;
-                        assert(item_const_val->ok);
-                        assert(get_expr_type(item_node)->id == TypeTableEntryIdEnum);
-                        val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
-                                item_const_val->data.x_enum.tag, false);
-                    } else {
-                        zig_unreachable();
-                    }
+                if (target_type->id == TypeTableEntryIdEnum) {
+                    assert(item_node->type == NodeTypeSymbol);
+                    TypeEnumField *enum_field = item_node->data.symbol_expr.enum_field;
+                    assert(enum_field);
+                    val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
+                            enum_field->value, false);
                 } else {
-                    val = val_handle;
+                    assert(get_resolved_expr(item_node)->const_val.ok);
+                    val = gen_expr(g, item_node);
                 }
                 LLVMAddCase(switch_instr, val, prong_block);
             }
test/run_tests.cpp
@@ -1157,32 +1157,6 @@ fn fn3() -> u32 {7}
 fn fn4() -> u32 {8}
     )SOURCE", "5\n6\n7\n8\n");
 
-    add_simple_case("switch statement", R"SOURCE(
-import "std.zig";
-
-enum Foo {
-    A,
-    B,
-    C,
-    D,
-}
-
-pub fn main(args: [][]u8) -> %void {
-    const foo = Foo.C;
-    const val: i32 = switch (foo) {
-        Foo.A => 1,
-        Foo.B => 2,
-        Foo.C => 3,
-        Foo.D => 4,
-    };
-    if (val != 3) {
-        %%stdout.printf("BAD\n");
-    }
-
-    %%stdout.printf("OK\n");
-}
-    )SOURCE", "OK\n");
-
     add_simple_case("const number literal", R"SOURCE(
 import "std.zig";
 
test/self_hosted.zig
@@ -59,14 +59,14 @@ fn constant_enum_with_payload() {
 
 fn should_be_empty(x: AnEnumWithPayload) {
     switch (x) {
-        AnEnumWithPayload.Empty => {},
+        Empty => {},
         else => unreachable{},
     }
 }
 
 fn should_be_not_empty(x: AnEnumWithPayload) {
     switch (x) {
-        AnEnumWithPayload.Empty => unreachable{},
+        Empty => unreachable{},
         else => {},
     }
 }
@@ -111,9 +111,9 @@ fn non_const_cast_bool_to_int(t: bool, f: bool) {
 fn switch_on_enum() {
     const fruit = Fruit.Orange;
     switch (fruit) {
-        Fruit.Apple => unreachable{},
-        Fruit.Orange => {},
-        Fruit.Banana => unreachable{},
+        Apple => unreachable{},
+        Orange => {},
+        Banana => unreachable{},
     }
     non_const_switch_on_enum(fruit);
 }
@@ -124,8 +124,26 @@ enum Fruit {
 }
 fn non_const_switch_on_enum(fruit: Fruit) {
     switch (fruit) {
-        Fruit.Apple => unreachable{},
-        Fruit.Orange => {},
-        Fruit.Banana => unreachable{},
+        Apple => unreachable{},
+        Orange => {},
+        Banana => unreachable{},
     }
 }
+
+#attribute("test")
+fn switch_statement() {
+    const foo = SwitchStatmentFoo.C;
+    const val: i32 = switch (foo) {
+        A => 1,
+        B => 2,
+        C => 3,
+        D => 4,
+    };
+    if (val != 3) unreachable{};
+}
+enum SwitchStatmentFoo {
+    A,
+    B,
+    C,
+    D,
+}