Commit 2ed72022ce

Andrew Kelley <superjoe30@gmail.com>
2016-05-08 02:00:58
support generic data structures
See #22
1 parent 01c46ee
doc/langref.md
@@ -15,7 +15,7 @@ GlobalVarDecl = VariableDeclaration ";"
 
 VariableDeclaration = ("var" | "const") "Symbol" option(":" TypeExpr) "=" Expression
 
-ContainerDecl = ("struct" | "enum" | "union") "Symbol" "{" many(StructMember) "}"
+ContainerDecl = ("struct" | "enum" | "union") "Symbol" option(ParamDeclList) "{" many(StructMember) "}"
 
 StructMember = many(Directive) option(VisibleMod) (StructField | FnDef)
 
src/all_types.hpp
@@ -600,12 +600,16 @@ struct AstNodeStructDecl {
     TopLevelDecl top_level_decl;
     Buf name;
     ContainerKind kind;
+    ZigList<AstNode *> generic_params;
+    bool generic_params_is_var_args; // always an error but it can happen from parsing
     ZigList<AstNode *> fields;
     ZigList<AstNode *> fns;
 
     // populated by semantic analyzer
     BlockContext *block_context;
     TypeTableEntry *type_entry;
+    TypeTableEntry *generic_fn_type;
+    bool skip;
 };
 
 struct AstNodeStructField {
src/analyze.cpp
@@ -756,11 +756,11 @@ static TypeTableEntryId container_to_type(ContainerKind kind) {
     zig_unreachable();
 }
 
-TypeTableEntry *get_partial_container_type(CodeGen *g, ImportTableEntry *import,
+TypeTableEntry *get_partial_container_type(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         ContainerKind kind, AstNode *decl_node, const char *name)
 {
     TypeTableEntryId type_id = container_to_type(kind);
-    TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, import->block_context);
+    TypeTableEntry *entry = new_container_type_entry(type_id, decl_node, context);
 
     switch (kind) {
         case ContainerKindStruct:
@@ -1171,6 +1171,8 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt
     uint64_t biggest_align_in_bits = 0;
     uint64_t biggest_union_member_size_in_bits = 0;
 
+    BlockContext *context = enum_type->data.enumeration.block_context;
+
     // set temporary flag
     enum_type->data.enumeration.embedded_in_current = true;
 
@@ -1179,7 +1181,7 @@ static void resolve_enum_type(CodeGen *g, ImportTableEntry *import, TypeTableEnt
         AstNode *field_node = decl_node->data.struct_decl.fields.at(i);
         TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[i];
         type_enum_field->name = &field_node->data.struct_field.name;
-        TypeTableEntry *field_type = analyze_type_expr(g, import, import->block_context,
+        TypeTableEntry *field_type = analyze_type_expr(g, import, context,
                 field_node->data.struct_field.type);
         type_enum_field->type_entry = field_type;
         type_enum_field->value = i;
@@ -1362,12 +1364,14 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE
     // this field should be set to true only during the recursive calls to resolve_struct_type
     struct_type->data.structure.embedded_in_current = true;
 
+    BlockContext *context = struct_type->data.structure.block_context;
+
     int gen_field_index = 0;
     for (int i = 0; i < field_count; i += 1) {
         AstNode *field_node = decl_node->data.struct_decl.fields.at(i);
         TypeStructField *type_struct_field = &struct_type->data.structure.fields[i];
         type_struct_field->name = &field_node->data.struct_field.name;
-        TypeTableEntry *field_type = analyze_type_expr(g, import, import->block_context,
+        TypeTableEntry *field_type = analyze_type_expr(g, import, context,
                 field_node->data.struct_field.type);
         type_struct_field->type_entry = field_type;
         type_struct_field->src_index = i;
@@ -1469,16 +1473,28 @@ static void get_fully_qualified_decl_name(Buf *buf, AstNode *decl_node, uint8_t
 }
 
 static void preview_generic_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *node) {
-    assert(node->type == NodeTypeFnProto);
+    if (node->type == NodeTypeFnProto) {
+        if (node->data.fn_proto.generic_params_is_var_args) {
+            add_node_error(g, node, buf_sprintf("generic parameters cannot be var args"));
+            node->data.fn_proto.skip = true;
+            node->data.fn_proto.generic_fn_type = g->builtin_types.entry_invalid;
+            return;
+        }
 
-    if (node->data.fn_proto.generic_params_is_var_args) {
-        add_node_error(g, node, buf_sprintf("generic parameters cannot be var args"));
-        node->data.fn_proto.skip = true;
-        node->data.fn_proto.generic_fn_type = g->builtin_types.entry_invalid;
-        return;
+        node->data.fn_proto.generic_fn_type = get_generic_fn_type(g, node);
+    } else if (node->type == NodeTypeStructDecl) {
+        if (node->data.struct_decl.generic_params_is_var_args) {
+            add_node_error(g, node, buf_sprintf("generic parameters cannot be var args"));
+            node->data.struct_decl.skip = true;
+            node->data.struct_decl.generic_fn_type = g->builtin_types.entry_invalid;
+            return;
+        }
+
+        node->data.struct_decl.generic_fn_type = get_generic_fn_type(g, node);
+    } else {
+        zig_unreachable();
     }
 
-    node->data.fn_proto.generic_fn_type = get_generic_fn_type(g, node);
 }
 
 static void preview_fn_proto_instance(CodeGen *g, ImportTableEntry *import, AstNode *proto_node,
@@ -1538,6 +1554,50 @@ static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *prot
 
 }
 
+static void scan_struct_decl(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node) {
+    assert(node->type == NodeTypeStructDecl);
+
+    Buf *name = &node->data.struct_decl.name;
+    TypeTableEntry *container_type = get_partial_container_type(g, import, context,
+            node->data.struct_decl.kind, node, buf_ptr(name));
+    node->data.struct_decl.type_entry = container_type;
+
+    // handle the member function definitions independently
+    for (int i = 0; i < node->data.struct_decl.fns.length; i += 1) {
+        AstNode *child_node = node->data.struct_decl.fns.at(i);
+        get_as_top_level_decl(child_node)->parent_decl = node;
+        BlockContext *child_context = get_container_block_context(container_type);
+        scan_decls(g, import, child_context, child_node);
+    }
+}
+
+static void resolve_struct_instance(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+    TypeTableEntry *type_entry = node->data.struct_decl.type_entry;
+    assert(type_entry);
+
+    // struct/enum member fns will get resolved independently
+
+    switch (node->data.struct_decl.kind) {
+        case ContainerKindStruct:
+            resolve_struct_type(g, import, type_entry);
+            break;
+        case ContainerKindEnum:
+            resolve_enum_type(g, import, type_entry);
+            break;
+        case ContainerKindUnion:
+            resolve_union_type(g, import, type_entry);
+            break;
+    }
+}
+
+static void resolve_struct_decl(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+    if (node->data.struct_decl.generic_params.length > 0) {
+        return preview_generic_fn_proto(g, import, node);
+    } else {
+        return resolve_struct_instance(g, import, node);
+    }
+}
+
 static void preview_error_value_decl(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeErrorValueDecl);
 
@@ -1587,25 +1647,8 @@ static void resolve_top_level_decl(CodeGen *g, AstNode *node, bool pointer_only)
             preview_fn_proto(g, import, node);
             break;
         case NodeTypeStructDecl:
-            {
-                TypeTableEntry *type_entry = node->data.struct_decl.type_entry;
-
-                // struct/enum member fns will get resolved independently
-
-                switch (node->data.struct_decl.kind) {
-                    case ContainerKindStruct:
-                        resolve_struct_type(g, import, type_entry);
-                        break;
-                    case ContainerKindEnum:
-                        resolve_enum_type(g, import, type_entry);
-                        break;
-                    case ContainerKindUnion:
-                        resolve_union_type(g, import, type_entry);
-                        break;
-                }
-
-                break;
-            }
+            resolve_struct_decl(g, import, node);
+            break;
         case NodeTypeVariableDeclaration:
             {
                 AstNodeVariableDeclaration *variable_declaration = &node->data.variable_declaration;
@@ -2729,6 +2772,7 @@ static TypeTableEntry *resolve_expr_const_val_as_generic_fn(CodeGen *g, AstNode
     return type_entry;
 }
 
+
 static TypeTableEntry *resolve_expr_const_val_as_err(CodeGen *g, AstNode *node, ErrorTableEntry *err) {
     Expr *expr = get_resolved_expr(node);
     expr->const_val.ok = true;
@@ -2888,7 +2932,13 @@ static TypeTableEntry *analyze_decl_ref(CodeGen *g, AstNode *source_node, AstNod
             return resolve_expr_const_val_as_fn(g, source_node, fn_entry);
         }
     } else if (decl_node->type == NodeTypeStructDecl) {
-        return resolve_expr_const_val_as_type(g, source_node, decl_node->data.struct_decl.type_entry);
+        if (decl_node->data.struct_decl.generic_params.length > 0) {
+            TypeTableEntry *type_entry = decl_node->data.struct_decl.generic_fn_type;
+            assert(type_entry);
+            return resolve_expr_const_val_as_generic_fn(g, source_node, type_entry);
+        } else {
+            return resolve_expr_const_val_as_type(g, source_node, decl_node->data.struct_decl.type_entry);
+        }
     } else if (decl_node->type == NodeTypeTypeDecl) {
         return resolve_expr_const_val_as_type(g, source_node, decl_node->data.type_decl.child_type_entry);
     } else {
@@ -5043,9 +5093,16 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp
     assert(generic_fn_type->id == TypeTableEntryIdGenericFn);
 
     AstNode *decl_node = generic_fn_type->data.generic_fn.decl_node;
-    assert(decl_node->type == NodeTypeFnProto);
+    ZigList<AstNode *> *generic_params;
+    if (decl_node->type == NodeTypeFnProto) {
+        generic_params = &decl_node->data.fn_proto.generic_params;
+    } else if (decl_node->type == NodeTypeStructDecl) {
+        generic_params = &decl_node->data.struct_decl.generic_params;
+    } else {
+        zig_unreachable();
+    }
 
-    int expected_param_count = decl_node->data.fn_proto.generic_params.length;
+    int expected_param_count = generic_params->length;
     int actual_param_count = node->data.fn_call_expr.params.length;
 
     if (actual_param_count != expected_param_count) {
@@ -5061,7 +5118,7 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp
 
     BlockContext *child_context = decl_node->owner->block_context;
     for (int i = 0; i < actual_param_count; i += 1) {
-        AstNode *generic_param_decl_node = decl_node->data.fn_proto.generic_params.at(i);
+        AstNode *generic_param_decl_node = generic_params->at(i);
         assert(generic_param_decl_node->type == NodeTypeParamDecl);
 
         AstNode **generic_param_type_node = &generic_param_decl_node->data.param_decl.type;
@@ -5104,24 +5161,36 @@ static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *imp
     auto entry = g->generic_table.maybe_get(generic_fn_type_id);
     if (entry) {
         AstNode *impl_decl_node = entry->value;
-        assert(impl_decl_node->type == NodeTypeFnProto);
-        FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry;
-        return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+        if (impl_decl_node->type == NodeTypeFnProto) {
+            FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry;
+            return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+        } else if (impl_decl_node->type == NodeTypeStructDecl) {
+            TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry;
+            return resolve_expr_const_val_as_type(g, node, type_entry);
+        } else {
+            zig_unreachable();
+        }
     }
 
     // make a type from the generic parameters supplied
-    assert(decl_node->type == NodeTypeFnProto);
-    AstNode *impl_fn_def_node = ast_clone_subtree(decl_node->data.fn_proto.fn_def_node, &g->next_node_index);
-    AstNode *impl_decl_node = impl_fn_def_node->data.fn_def.fn_proto;
-
-
+    if (decl_node->type == NodeTypeFnProto) {
+        AstNode *impl_fn_def_node = ast_clone_subtree(decl_node->data.fn_proto.fn_def_node, &g->next_node_index);
+        AstNode *impl_decl_node = impl_fn_def_node->data.fn_def.fn_proto;
 
-    preview_fn_proto_instance(g, import, impl_decl_node, child_context);
-
-    g->generic_table.put(generic_fn_type_id, impl_decl_node);
-
-    FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry;
-    return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+        preview_fn_proto_instance(g, import, impl_decl_node, child_context);
+        g->generic_table.put(generic_fn_type_id, impl_decl_node);
+        FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry;
+        return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+    } else if (decl_node->type == NodeTypeStructDecl) {
+        AstNode *impl_decl_node = ast_clone_subtree(decl_node, &g->next_node_index);
+        g->generic_table.put(generic_fn_type_id, impl_decl_node);
+        scan_struct_decl(g, import, child_context, impl_decl_node);
+        TypeTableEntry *type_entry = impl_decl_node->data.struct_decl.type_entry;
+        resolve_struct_type(g, import, type_entry);
+        return resolve_expr_const_val_as_type(g, node, type_entry);
+    } else {
+        zig_unreachable();
+    }
 }
 
 static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
@@ -6065,7 +6134,8 @@ static void add_top_level_decl(CodeGen *g, ImportTableEntry *import, BlockContex
     tld->name = name;
 
     bool want_as_export = (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport);
-    bool is_generic = (node->type == NodeTypeFnProto && node->data.fn_proto.generic_params.length > 0);
+    bool is_generic = (node->type == NodeTypeFnProto && node->data.fn_proto.generic_params.length > 0) ||
+                      (node->type == NodeTypeStructDecl && node->data.struct_decl.generic_params.length > 0);
     if (!is_generic && want_as_export) {
         g->export_queue.append(node);
     }
@@ -6093,21 +6163,12 @@ static void scan_decls(CodeGen *g, ImportTableEntry *import, BlockContext *conte
         case NodeTypeStructDecl:
             {
                 Buf *name = &node->data.struct_decl.name;
-                TypeTableEntry *container_type = get_partial_container_type(g, import,
-                        node->data.struct_decl.kind, node, buf_ptr(name));
-                node->data.struct_decl.type_entry = container_type;
                 add_top_level_decl(g, import, context, node, name);
-
-                // handle the member function definitions independently
-                for (int i = 0; i < node->data.struct_decl.fns.length; i += 1) {
-                    AstNode *child_node = node->data.struct_decl.fns.at(i);
-                    get_as_top_level_decl(child_node)->parent_decl = node;
-                    BlockContext *child_context = get_container_block_context(container_type);
-                    scan_decls(g, import, child_context, child_node);
+                if (node->data.struct_decl.generic_params.length == 0) {
+                    scan_struct_decl(g, import, context, node);
                 }
-
-                break;
             }
+            break;
         case NodeTypeFnDef:
             node->data.fn_def.fn_proto->data.fn_proto.fn_def_node = node;
             scan_decls(g, import, context, node->data.fn_def.fn_proto);
src/analyze.hpp
@@ -27,7 +27,7 @@ TypeTableEntry *get_fn_type(CodeGen *g, FnTypeId *fn_type_id);
 TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type);
 TypeTableEntry *get_array_type(CodeGen *g, TypeTableEntry *child_type, uint64_t array_size);
 TypeTableEntry *get_slice_type(CodeGen *g, TypeTableEntry *child_type, bool is_const);
-TypeTableEntry *get_partial_container_type(CodeGen *g, ImportTableEntry *import,
+TypeTableEntry *get_partial_container_type(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         ContainerKind kind, AstNode *decl_node, const char *name);
 TypeTableEntry *get_smallest_unsigned_int_type(CodeGen *g, uint64_t x);
 bool handle_is_ptr(TypeTableEntry *type_entry);
src/parseh.cpp
@@ -801,6 +801,7 @@ static TypeTableEntry *resolve_enum_decl(Context *c, const EnumDecl *enum_decl)
     const EnumDecl *enum_def = enum_decl->getDefinition();
     if (!enum_def) {
         TypeTableEntry *enum_type = get_partial_container_type(c->codegen, c->import,
+                c->import->block_context,
                 ContainerKindEnum, c->source_node, buf_ptr(full_type_name));
         c->enum_type_table.put(bare_name, enum_type);
         c->decl_table.put(enum_decl, enum_type);
@@ -825,6 +826,7 @@ static TypeTableEntry *resolve_enum_decl(Context *c, const EnumDecl *enum_decl)
 
     if (pure_enum) {
         TypeTableEntry *enum_type = get_partial_container_type(c->codegen, c->import,
+                c->import->block_context,
                 ContainerKindEnum, c->source_node, buf_ptr(full_type_name));
         c->enum_type_table.put(bare_name, enum_type);
         c->decl_table.put(enum_decl, enum_type);
@@ -985,7 +987,7 @@ static TypeTableEntry *resolve_record_decl(Context *c, const RecordDecl *record_
 
 
     TypeTableEntry *struct_type = get_partial_container_type(c->codegen, c->import,
-            ContainerKindStruct, c->source_node, buf_ptr(full_type_name));
+            c->import->block_context, ContainerKindStruct, c->source_node, buf_ptr(full_type_name));
 
     c->struct_type_table.put(bare_name, struct_type);
     c->decl_table.put(record_decl, struct_type);
src/parser.cpp
@@ -762,9 +762,7 @@ static void ast_parse_param_decl_list(ParseContext *pc, int *token_index,
 {
     *is_var_args = false;
 
-    Token *l_paren = &pc->tokens->at(*token_index);
-    *token_index += 1;
-    ast_expect_token(pc, l_paren, TokenIdLParen);
+    ast_eat_token(pc, token_index, TokenIdLParen);
 
     Token *token = &pc->tokens->at(*token_index);
     if (token->id == TokenIdRParen) {
@@ -2606,7 +2604,7 @@ static AstNode *ast_parse_use(ParseContext *pc, int *token_index,
 }
 
 /*
-ContainerDecl = ("struct" | "enum" | "union") "Symbol" "{" many(StructMember) "}"
+ContainerDecl = ("struct" | "enum" | "union") "Symbol" option(ParamDeclList) "{" many(StructMember) "}"
 StructMember: many(Directive) option(VisibleMod) (StructField | FnDef)
 StructField : "Symbol" option(":" Expression) ",")
 */
@@ -2636,7 +2634,16 @@ static AstNode *ast_parse_container_decl(ParseContext *pc, int *token_index,
     node->data.struct_decl.top_level_decl.visib_mod = visib_mod;
     node->data.struct_decl.top_level_decl.directives = directives;
 
-    ast_eat_token(pc, token_index, TokenIdLBrace);
+    Token *paren_or_brace = &pc->tokens->at(*token_index);
+    if (paren_or_brace->id == TokenIdLParen) {
+        ast_parse_param_decl_list(pc, token_index, &node->data.struct_decl.generic_params,
+                &node->data.struct_decl.generic_params_is_var_args);
+        ast_eat_token(pc, token_index, TokenIdLBrace);
+    } else if (paren_or_brace->id == TokenIdLBrace) {
+        *token_index += 1;
+    } else {
+        ast_invalid_token_error(pc, paren_or_brace);
+    }
 
     for (;;) {
         Token *directive_token = &pc->tokens->at(*token_index);
test/self_hosted.zig
@@ -1566,3 +1566,15 @@ fn c_string_concatenation() {
     assert(a[len] == 0);
     assert(b[len] == 0);
 }
+
+#attribute("test")
+fn generic_struct() {
+    var a1 = GenNode(i32) {.value = 13, .next = null,};
+    var b1 = GenNode(bool) {.value = true, .next = null,};
+    assert(a1.value == 13);
+    assert(b1.value);
+}
+struct GenNode(T: type) {
+    value: T,
+    next: ?&GenNode(T),
+}