Commit e06885d64e

Andrew Kelley <superjoe30@gmail.com>
2016-09-23 21:18:02
enums support member functions
1 parent 9ec6a78
src/all_types.hpp
@@ -924,7 +924,7 @@ struct TypeTableEntryError {
 
 struct TypeTableEntryEnum {
     AstNode *decl_node;
-    uint32_t field_count;
+    uint32_t src_field_count;
     uint32_t gen_field_count;
     TypeEnumField *fields;
     bool is_invalid; // true if any fields are invalid
src/analyze.cpp
@@ -1325,7 +1325,7 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt
 
     uint32_t field_count = decl_node->data.struct_decl.fields.length;
 
-    enum_type->data.enumeration.field_count = field_count;
+    enum_type->data.enumeration.src_field_count = field_count;
     enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
     ZigLLVMDIEnumerator **di_enumerators = allocate<ZigLLVMDIEnumerator*>(field_count);
 
@@ -2451,8 +2451,8 @@ static LabelTableEntry *find_label(CodeGen *g, BlockContext *orig_context, Buf *
     return nullptr;
 }
 
-static TypeEnumField *get_enum_field(TypeTableEntry *enum_type, Buf *name) {
-    for (uint32_t i = 0; i < enum_type->data.enumeration.field_count; i += 1) {
+static TypeEnumField *find_enum_type_field(TypeTableEntry *enum_type, Buf *name) {
+    for (uint32_t i = 0; i < enum_type->data.enumeration.src_field_count; i += 1) {
         TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[i];
         if (buf_eql_buf(type_enum_field->name, name)) {
             return type_enum_field;
@@ -2467,7 +2467,7 @@ static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *imp
 {
     assert(field_access_node->type == NodeTypeFieldAccessExpr);
 
-    TypeEnumField *type_enum_field = get_enum_field(enum_type, field_name);
+    TypeEnumField *type_enum_field = find_enum_type_field(enum_type, field_name);
     if (type_enum_field->type_entry->id == TypeTableEntryIdInvalid) {
         return g->builtin_types.entry_invalid;
     }
@@ -2715,6 +2715,41 @@ static TypeTableEntry *analyze_container_init_expr(CodeGen *g, ImportTableEntry
     }
 }
 
+static TypeTableEntry *analyze_member_access(CodeGen *g, bool wrapped_in_fn_call,
+    TypeTableEntry *bare_struct_type, Buf *field_name, AstNode *node, TypeTableEntry *struct_type)
+{
+    assert(node->type == NodeTypeFieldAccessExpr);
+    if (wrapped_in_fn_call && !is_slice(bare_struct_type)) {
+        BlockContext *container_block_context = get_container_block_context(bare_struct_type);
+        assert(container_block_context);
+        auto entry = container_block_context->decl_table.maybe_get(field_name);
+        AstNode *fn_decl_node = entry ? entry->value : nullptr;
+        if (fn_decl_node && fn_decl_node->type == NodeTypeFnProto) {
+            resolve_top_level_decl(g, fn_decl_node, false);
+            TopLevelDecl *tld = get_as_top_level_decl(fn_decl_node);
+            if (tld->resolution == TldResolutionInvalid) {
+                return g->builtin_types.entry_invalid;
+            }
+
+            node->data.field_access_expr.is_member_fn = true;
+            FnTableEntry *fn_entry = fn_decl_node->data.fn_proto.fn_table_entry;
+            if (fn_entry->type_entry->id == TypeTableEntryIdGenericFn) {
+                return resolve_expr_const_val_as_generic_fn(g, node, fn_entry->type_entry, false);
+            } else {
+                return resolve_expr_const_val_as_fn(g, node, fn_entry, false);
+            }
+        } else {
+            add_node_error(g, node, buf_sprintf("no function named '%s' in '%s'",
+                buf_ptr(field_name), buf_ptr(&bare_struct_type->name)));
+            return g->builtin_types.entry_invalid;
+        }
+    } else {
+        add_node_error(g, node,
+            buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&struct_type->name)));
+        return g->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -2742,34 +2777,28 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i
         node->data.field_access_expr.type_struct_field = find_struct_type_field(bare_struct_type, field_name);
         if (node->data.field_access_expr.type_struct_field) {
             return node->data.field_access_expr.type_struct_field->type_entry;
-        } else if (wrapped_in_fn_call && !is_slice(bare_struct_type)) {
-            BlockContext *container_block_context = get_container_block_context(bare_struct_type);
-            assert(container_block_context);
-            auto entry = container_block_context->decl_table.maybe_get(field_name);
-            AstNode *fn_decl_node = entry ? entry->value : nullptr;
-            if (fn_decl_node && fn_decl_node->type == NodeTypeFnProto) {
-                resolve_top_level_decl(g, fn_decl_node, false);
-                TopLevelDecl *tld = get_as_top_level_decl(fn_decl_node);
-                if (tld->resolution == TldResolutionInvalid) {
-                    return g->builtin_types.entry_invalid;
-                }
+        } else {
+            return analyze_member_access(g, wrapped_in_fn_call, bare_struct_type, field_name,
+                node, struct_type);
+        }
+    } else if (struct_type->id == TypeTableEntryIdEnum || (struct_type->id == TypeTableEntryIdPointer &&
+        struct_type->data.pointer.child_type->id == TypeTableEntryIdEnum))
+    {
+        TypeTableEntry *bare_struct_type = (struct_type->id == TypeTableEntryIdEnum) ?
+            struct_type : struct_type->data.pointer.child_type;
 
-                node->data.field_access_expr.is_member_fn = true;
-                FnTableEntry *fn_entry = fn_decl_node->data.fn_proto.fn_table_entry;
-                if (fn_entry->type_entry->id == TypeTableEntryIdGenericFn) {
-                    return resolve_expr_const_val_as_generic_fn(g, node, fn_entry->type_entry, false);
-                } else {
-                    return resolve_expr_const_val_as_fn(g, node, fn_entry, false);
-                }
-            } else {
-                add_node_error(g, node, buf_sprintf("no function named '%s' in '%s'",
-                    buf_ptr(field_name), buf_ptr(&bare_struct_type->name)));
-                return g->builtin_types.entry_invalid;
-            }
+        if (!bare_struct_type->data.enumeration.complete) {
+            resolve_struct_type(g, bare_struct_type->data.enumeration.decl_node->owner, bare_struct_type);
+        }
+
+        node->data.field_access_expr.bare_struct_type = bare_struct_type;
+        node->data.field_access_expr.type_enum_field = find_enum_type_field(bare_struct_type, field_name);
+
+        if (node->data.field_access_expr.type_enum_field) {
+            return node->data.field_access_expr.type_enum_field->type_entry;
         } else {
-            add_node_error(g, node,
-                buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&struct_type->name)));
-            return g->builtin_types.entry_invalid;
+            return analyze_member_access(g, wrapped_in_fn_call, bare_struct_type, field_name,
+                node, struct_type);
         }
     } else if (struct_type->id == TypeTableEntryIdArray) {
         if (buf_eql_str(field_name, "len")) {
@@ -5269,7 +5298,7 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
                 if (type_entry->id == TypeTableEntryIdInvalid) {
                     return type_entry;
                 } else if (type_entry->id == TypeTableEntryIdEnum) {
-                    uint64_t value_count = type_entry->data.enumeration.field_count;
+                    uint64_t value_count = type_entry->data.enumeration.src_field_count;
                     return resolve_expr_const_val_as_unsigned_num_lit(g, node, expected_type,
                             value_count, false);
                 } else {
@@ -6130,7 +6159,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
     size_t *field_use_counts = nullptr;
     HashMap<int, AstNode *, int_hash, int_eq> err_use_nodes = {};
     if (expr_type->id == TypeTableEntryIdEnum) {
-        field_use_counts = allocate<size_t>(expr_type->data.enumeration.field_count);
+        field_use_counts = allocate<size_t>(expr_type->data.enumeration.src_field_count);
     } else if (expr_type->id == TypeTableEntryIdErrorUnion) {
         err_use_nodes.init(10);
     }
@@ -6168,7 +6197,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                 if (expr_type->id == TypeTableEntryIdEnum) {
                     if (item_node->type == NodeTypeSymbol) {
                         Buf *field_name = item_node->data.symbol_expr.symbol;
-                        TypeEnumField *type_enum_field = get_enum_field(expr_type, field_name);
+                        TypeEnumField *type_enum_field = find_enum_type_field(expr_type, field_name);
                         if (type_enum_field) {
                             item_node->data.symbol_expr.enum_field = type_enum_field;
                             if (!var_type) {
@@ -6300,7 +6329,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
     }
 
     if (expr_type->id == TypeTableEntryIdEnum && !else_prong) {
-        for (uint32_t i = 0; i < expr_type->data.enumeration.field_count; i += 1) {
+        for (uint32_t i = 0; i < expr_type->data.enumeration.src_field_count; i += 1) {
             if (field_use_counts[i] == 0) {
                 add_node_error(g, node,
                     buf_sprintf("enumeration value '%s' not handled in switch",
src/codegen.cpp
@@ -4605,7 +4605,7 @@ static void define_builtin_types(CodeGen *g) {
         entry->zero_bits = true; // only allowed at compile time
         buf_init_from_str(&entry->name, "@OS");
         uint32_t field_count = target_os_count();
-        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.src_field_count = field_count;
         entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         for (uint32_t i = 0; i < field_count; i += 1) {
             TypeEnumField *type_enum_field = &entry->data.enumeration.fields[i];
@@ -4631,7 +4631,7 @@ static void define_builtin_types(CodeGen *g) {
         entry->zero_bits = true; // only allowed at compile time
         buf_init_from_str(&entry->name, "@Arch");
         uint32_t field_count = target_arch_count();
-        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.src_field_count = field_count;
         entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         for (uint32_t i = 0; i < field_count; i += 1) {
             TypeEnumField *type_enum_field = &entry->data.enumeration.fields[i];
@@ -4663,7 +4663,7 @@ static void define_builtin_types(CodeGen *g) {
         entry->zero_bits = true; // only allowed at compile time
         buf_init_from_str(&entry->name, "@Environ");
         uint32_t field_count = target_environ_count();
-        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.src_field_count = field_count;
         entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         for (uint32_t i = 0; i < field_count; i += 1) {
             TypeEnumField *type_enum_field = &entry->data.enumeration.fields[i];
@@ -4690,7 +4690,7 @@ static void define_builtin_types(CodeGen *g) {
         entry->zero_bits = true; // only allowed at compile time
         buf_init_from_str(&entry->name, "@ObjectFormat");
         uint32_t field_count = target_oformat_count();
-        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.src_field_count = field_count;
         entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         for (uint32_t i = 0; i < field_count; i += 1) {
             TypeEnumField *type_enum_field = &entry->data.enumeration.fields[i];
@@ -4716,7 +4716,7 @@ static void define_builtin_types(CodeGen *g) {
         entry->deep_const = true;
         buf_init_from_str(&entry->name, "AtomicOrder");
         uint32_t field_count = 6;
-        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.src_field_count = field_count;
         entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         entry->data.enumeration.fields[0].name = buf_create_from_str("Unordered");
         entry->data.enumeration.fields[0].value = AtomicOrderUnordered;
src/eval.cpp
@@ -683,7 +683,7 @@ void eval_const_expr_implicit_cast(CastOp cast_op,
             {
                 uint64_t value = other_val->data.x_bignum.data.x_uint;
                 assert(new_type->id == TypeTableEntryIdEnum);
-                assert(value < new_type->data.enumeration.field_count);
+                assert(value < new_type->data.enumeration.src_field_count);
                 const_val->data.x_enum.tag = value;
                 const_val->data.x_enum.payload = NULL;
                 const_val->ok = true;
src/parseh.cpp
@@ -870,7 +870,7 @@ static TypeTableEntry *resolve_enum_decl(Context *c, const EnumDecl *enum_decl)
         enum_type->data.enumeration.complete = true;
         enum_type->data.enumeration.tag_type = tag_type_entry;
 
-        enum_type->data.enumeration.field_count = field_count;
+        enum_type->data.enumeration.src_field_count = field_count;
         enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
         ZigLLVMDIEnumerator **di_enumerators = allocate<ZigLLVMDIEnumerator*>(field_count);
 
@@ -977,7 +977,7 @@ static void visit_enum_decl(Context *c, const EnumDecl *enum_decl) {
             enum_node->data.struct_decl.top_level_decl.visib_mod = VisibModExport;
             enum_node->data.struct_decl.type_entry = enum_type;
 
-            for (uint32_t i = 0; i < enum_type->data.enumeration.field_count; i += 1) {
+            for (uint32_t i = 0; i < enum_type->data.enumeration.src_field_count; i += 1) {
                 TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[i];
                 AstNode *type_node = make_type_node(c, type_enum_field->type_entry);
                 AstNode *field_node = create_struct_field_node(c, buf_ptr(type_enum_field->name), type_node);
test/cases/enum_with_members.zig
@@ -0,0 +1,30 @@
+const std = @import("std");
+const assert = std.debug.assert;
+const io = std.io;
+const str = std.str;
+
+enum ET {
+    SINT: i32,
+    UINT: u32,
+
+    pub fn print(a: &ET, buf: []u8) -> %usize {
+        return switch (*a) {
+            SINT => |x| { io.bufPrintInt(i32, buf, x) },
+            UINT => |x| { io.bufPrintInt(u32, buf, x) },
+        }
+    }
+}
+
+#attribute("test")
+fn enumWithMembers() {
+    const a = ET.SINT { -42 };
+    const b = ET.UINT { 42 };
+    var buf: [20]u8 = undefined;
+
+    assert(%%a.print(buf) == 3);
+    assert(str.eql(buf[0...3], "-42"));
+
+    assert(%%b.print(buf) == 2);
+    assert(str.eql(buf[0...2], "42"));
+}
+
test/self_hosted.zig
@@ -13,6 +13,7 @@ const test_var_params = @import("cases/var_params.zig");
 const test_const_slice_child = @import("cases/const_slice_child.zig");
 const test_switch_prong_implicit_cast = @import("cases/switch_prong_implicit_cast.zig");
 const test_switch_prong_err_enum = @import("cases/switch_prong_err_enum.zig");
+const test_enum_with_members = @import("cases/enum_with_members.zig");
 
 // normal comment
 /// this is a documentation comment