Commit 67152f7294

Andrew Kelley <superjoe30@gmail.com>
2016-04-06 20:38:12
support simple generic functions
1 parent e144dda
doc/langref.md
@@ -25,7 +25,7 @@ UseDecl = "use" Expression ";"
 
 ExternDecl = "extern" (FnProto | VariableDeclaration) ";"
 
-FnProto = "fn" option("Symbol") ParamDeclList option("->" TypeExpr)
+FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr)
 
 Directive = "#" "Symbol" "(" Expression ")"
 
src/all_types.hpp
@@ -193,8 +193,10 @@ struct AstNodeRoot {
 struct AstNodeFnProto {
     TopLevelDecl top_level_decl;
     Buf name;
+    ZigList<AstNode *> generic_params;
     ZigList<AstNode *> params;
     AstNode *return_type;
+    bool generic_params_is_var_args;
     bool is_var_args;
     bool is_extern;
     bool is_inline;
@@ -206,6 +208,7 @@ struct AstNodeFnProto {
     FnTableEntry *fn_table_entry;
     bool skip;
     Expr resolved_expr;
+    TypeTableEntry *generic_fn_type;
 };
 
 struct AstNodeFnDef {
@@ -797,6 +800,21 @@ struct FnTypeParamInfo {
     TypeTableEntry *type;
 };
 
+struct GenericParamValue {
+    TypeTableEntry *type;
+    AstNode *node;
+};
+
+struct GenericFnTypeId {
+    AstNode *decl_node; // the generic fn or container decl node
+    GenericParamValue *generic_params;
+    int generic_param_count;
+};
+
+uint32_t generic_fn_type_id_hash(GenericFnTypeId *id);
+bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b);
+
+
 static const int fn_type_id_prealloc_param_info_count = 4;
 struct FnTypeId {
     TypeTableEntry *return_type;
@@ -812,7 +830,6 @@ struct FnTypeId {
 uint32_t fn_type_id_hash(FnTypeId*);
 bool fn_type_id_eql(FnTypeId *a, FnTypeId *b);
 
-
 struct TypeTableEntryPointer {
     TypeTableEntry *child_type;
     bool is_const;
@@ -899,6 +916,10 @@ struct TypeTableEntryFn {
     LLVMCallConv calling_convention;
 };
 
+struct TypeTableEntryGenericFn {
+    AstNode *decl_node;
+};
+
 struct TypeTableEntryTypeDecl {
     TypeTableEntry *child_type;
     TypeTableEntry *canonical_type;
@@ -925,6 +946,7 @@ enum TypeTableEntryId {
     TypeTableEntryIdFn,
     TypeTableEntryIdTypeDecl,
     TypeTableEntryIdNamespace,
+    TypeTableEntryIdGenericFn,
 };
 
 struct TypeTableEntry {
@@ -947,6 +969,7 @@ struct TypeTableEntry {
         TypeTableEntryEnum enumeration;
         TypeTableEntryFn fn;
         TypeTableEntryTypeDecl type_decl;
+        TypeTableEntryGenericFn generic_fn;
     } data;
 
     // use these fields to make sure we don't duplicate type table entries for the same type
@@ -992,6 +1015,7 @@ struct FnTableEntry {
     bool internal_linkage;
     bool is_extern;
     bool is_test;
+    BlockContext *parent_block_context;
 
     ZigList<AstNode *> cast_alloca_list;
     ZigList<StructValExprCodeGen *> struct_val_expr_alloca_list;
@@ -1047,6 +1071,7 @@ struct CodeGen {
     HashMap<Buf *, TypeTableEntry *, buf_hash, buf_eql_buf> primitive_type_table;
     HashMap<FnTypeId *, TypeTableEntry *, fn_type_id_hash, fn_type_id_eql> fn_type_table;
     HashMap<Buf *, ErrorTableEntry *, buf_hash, buf_eql_buf> error_table;
+    HashMap<GenericFnTypeId *, AstNode *, generic_fn_type_id_hash, generic_fn_type_id_eql> generic_table;
 
     ZigList<ImportTableEntry *> import_queue;
     int import_queue_index;
@@ -1172,7 +1197,10 @@ struct VariableTableEntry {
     LLVMValueRef value_ref;
     bool is_const;
     bool is_ptr; // if true, value_ref is a pointer
+    // which node is the declaration of the variable
     AstNode *decl_node;
+    // which node contains the ConstExprValue for this variable's value
+    AstNode *val_node;
     LLVMZigDILocalVariable *di_loc_var;
     int src_arg_index;
     int gen_arg_index;
src/analyze.cpp
@@ -41,6 +41,7 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa
         AstNodeVariableDeclaration *variable_declaration,
         bool expr_is_maybe, AstNode *decl_node);
 static void scan_decls(CodeGen *g, ImportTableEntry *import, BlockContext *context, AstNode *node);
+static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry);
 
 static AstNode *first_executing_node(AstNode *node) {
     switch (node->type) {
@@ -192,6 +193,7 @@ static bool type_is_complete(TypeTableEntry *type_entry) {
         case TypeTableEntryIdFn:
         case TypeTableEntryIdTypeDecl:
         case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
             return true;
     }
     zig_unreachable();
@@ -201,6 +203,14 @@ TypeTableEntry *get_smallest_unsigned_int_type(CodeGen *g, uint64_t x) {
     return get_int_type(g, false, bits_needed_for_unsigned(x));
 }
 
+static TypeTableEntry *get_generic_fn_type(CodeGen *g, AstNode *decl_node) {
+    TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdGenericFn);
+    buf_init_from_str(&entry->name, "(generic function)");
+    entry->zero_bits = true;
+    entry->data.generic_fn.decl_node = decl_node;
+    return entry;
+}
+
 TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool is_const) {
     assert(child_type->id != TypeTableEntryIdInvalid);
     TypeTableEntry **parent_pointer = &child_type->pointer_parent[(is_const ? 1 : 0)];
@@ -776,7 +786,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor
     }
 
     fn_type_id.is_var_args = fn_proto->is_var_args;
-    fn_type_id.return_type = analyze_type_expr(g, import, import->block_context, node->data.fn_proto.return_type);
+    fn_type_id.return_type = analyze_type_expr(g, import, context, node->data.fn_proto.return_type);
 
     if (fn_type_id.return_type->id == TypeTableEntryIdInvalid) {
         fn_proto->skip = true;
@@ -785,7 +795,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor
     for (int i = 0; i < fn_type_id.param_count; i += 1) {
         AstNode *child = node->data.fn_proto.params.at(i);
         assert(child->type == NodeTypeParamDecl);
-        TypeTableEntry *type_entry = analyze_type_expr(g, import, import->block_context,
+        TypeTableEntry *type_entry = analyze_type_expr(g, import, context,
                 child->data.param_decl.type);
         switch (type_entry->id) {
             case TypeTableEntryIdInvalid:
@@ -797,6 +807,7 @@ static TypeTableEntry *analyze_fn_proto_type(CodeGen *g, ImportTableEntry *impor
             case TypeTableEntryIdMetaType:
             case TypeTableEntryIdUnreachable:
             case TypeTableEntryIdNamespace:
+            case TypeTableEntryIdGenericFn:
                 fn_proto->skip = true;
                 add_node_error(g, child->data.param_decl.type,
                     buf_sprintf("parameter of type '%s' not allowed'", buf_ptr(&type_entry->name)));
@@ -880,7 +891,7 @@ static bool resolve_const_expr_bool(CodeGen *g, ImportTableEntry *import, BlockC
 }
 
 static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry,
-        ImportTableEntry *import)
+        ImportTableEntry *import, BlockContext *containing_context)
 {
     assert(node->type == NodeTypeFnProto);
     AstNodeFnProto *fn_proto = &node->data.fn_proto;
@@ -946,7 +957,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
 
 
 
-    TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, import->block_context, nullptr, node,
+    TypeTableEntry *fn_type = analyze_fn_proto_type(g, import, containing_context, nullptr, node,
             is_naked, is_cold);
 
     fn_table_entry->type_entry = fn_type;
@@ -963,6 +974,8 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
     } else {
         symbol_name = buf_sprintf("_%s", buf_ptr(&fn_table_entry->symbol_name));
     }
+    // TODO mangle the name if it's a generic instance
+
     fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(symbol_name),
         fn_type->data.fn.raw_type_ref);
 
@@ -992,12 +1005,12 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
         unsigned flags = 0;
         bool is_optimized = g->is_release_build;
         LLVMZigDISubprogram *subprogram = LLVMZigCreateFunction(g->dbuilder,
-            import->block_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "",
+            containing_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "",
             import->di_file, line_number,
             fn_type->di_type, fn_table_entry->internal_linkage,
             is_definition, scope_line, flags, is_optimized, nullptr);
 
-        BlockContext *context = new_block_context(fn_table_entry->fn_def_node, import->block_context);
+        BlockContext *context = new_block_context(fn_table_entry->fn_def_node, containing_context);
         fn_table_entry->fn_def_node->data.fn_def.block_context = context;
         context->di_scope = LLVMZigSubprogramToScope(subprogram);
     }
@@ -1321,17 +1334,35 @@ static void get_fully_qualified_decl_name(Buf *buf, AstNode *decl_node, uint8_t
     }
 }
 
-static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *proto_node) {
+static void preview_generic_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+    assert(node->type == NodeTypeFnProto);
+
+    if (node->data.fn_proto.generic_params_is_var_args) {
+        add_node_error(g, node, buf_sprintf("generic parameters cannot be var args"));
+        node->data.fn_proto.skip = true;
+        node->data.fn_proto.generic_fn_type = g->builtin_types.entry_invalid;
+        return;
+    }
+
+    node->data.fn_proto.generic_fn_type = get_generic_fn_type(g, node);
+}
+
+static void preview_fn_proto_instance(CodeGen *g, ImportTableEntry *import, AstNode *proto_node,
+        BlockContext *containing_context)
+{
     if (proto_node->data.fn_proto.skip) {
         return;
     }
 
+    bool is_generic_instance = (proto_node->data.fn_proto.generic_params.length > 0);
+
     AstNode *parent_decl = proto_node->data.fn_proto.top_level_decl.parent_decl;
+    Buf *proto_name = &proto_node->data.fn_proto.name;
 
     AstNode *fn_def_node = proto_node->data.fn_proto.fn_def_node;
     bool is_extern = proto_node->data.fn_proto.is_extern;
 
-    Buf *proto_name = &proto_node->data.fn_proto.name;
+    assert(!is_extern || !is_generic_instance);
 
     if (!is_extern && proto_node->data.fn_proto.is_var_args) {
         add_node_error(g, proto_node,
@@ -1352,13 +1383,24 @@ static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *prot
         g->fn_defs.append(fn_table_entry);
     }
 
-    bool is_main_fn = !parent_decl && (import == g->root_import) && buf_eql_str(proto_name, "main");
+    bool is_main_fn = !is_generic_instance &&
+        !parent_decl && (import == g->root_import) &&
+        buf_eql_str(proto_name, "main");
     if (is_main_fn) {
         g->main_fn = fn_table_entry;
     }
 
     proto_node->data.fn_proto.fn_table_entry = fn_table_entry;
-    resolve_function_proto(g, proto_node, fn_table_entry, import);
+    resolve_function_proto(g, proto_node, fn_table_entry, import, containing_context);
+}
+
+static void preview_fn_proto(CodeGen *g, ImportTableEntry *import, AstNode *proto_node) {
+    if (proto_node->data.fn_proto.generic_params.length > 0) {
+        return preview_generic_fn_proto(g, import, proto_node);
+    } else {
+        return preview_fn_proto_instance(g, import, proto_node, import->block_context);
+    }
+
 }
 
 static void preview_error_value_decl(CodeGen *g, AstNode *node) {
@@ -1539,6 +1581,7 @@ static bool type_has_codegen_value(TypeTableEntry *type_entry) {
         case TypeTableEntryIdNumLitInt:
         case TypeTableEntryIdUndefLit:
         case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
             return false;
 
         case TypeTableEntryIdBool:
@@ -2433,6 +2476,15 @@ static TypeTableEntry *resolve_expr_const_val_as_fn(CodeGen *g, AstNode *node, F
     return fn->type_entry;
 }
 
+static TypeTableEntry *resolve_expr_const_val_as_generic_fn(CodeGen *g, AstNode *node,
+        TypeTableEntry *type_entry)
+{
+    Expr *expr = get_resolved_expr(node);
+    expr->const_val.ok = true;
+    expr->const_val.data.x_type = type_entry;
+    return type_entry;
+}
+
 static TypeTableEntry *resolve_expr_const_val_as_err(CodeGen *g, AstNode *node, ErrorTableEntry *err) {
     Expr *expr = get_resolved_expr(node);
     expr->const_val.ok = true;
@@ -2570,14 +2622,10 @@ static TypeTableEntry *analyze_error_literal_expr(CodeGen *g, ImportTableEntry *
 
 static TypeTableEntry *analyze_var_ref(CodeGen *g, AstNode *source_node, VariableTableEntry *var) {
     get_resolved_expr(source_node)->variable = var;
-    if (var->is_const) {
-        AstNode *decl_node = var->decl_node;
-        if (decl_node->type == NodeTypeVariableDeclaration) {
-            AstNode *expr_node = decl_node->data.variable_declaration.expr;
-            ConstExprValue *other_const_val = &get_resolved_expr(expr_node)->const_val;
-            if (other_const_val->ok) {
-                return resolve_expr_const_val_as_other_expr(g, source_node, expr_node);
-            }
+    if (var->is_const && var->val_node) {
+        ConstExprValue *other_const_val = &get_resolved_expr(var->val_node)->const_val;
+        if (other_const_val->ok) {
+            return resolve_expr_const_val_as_other_expr(g, source_node, var->val_node);
         }
     }
     return var->type;
@@ -2596,9 +2644,15 @@ static TypeTableEntry *analyze_decl_ref(CodeGen *g, AstNode *source_node, AstNod
         VariableTableEntry *var = decl_node->data.variable_declaration.variable;
         return analyze_var_ref(g, source_node, var);
     } else if (decl_node->type == NodeTypeFnProto) {
-        FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry;
-        assert(fn_entry->type_entry);
-        return resolve_expr_const_val_as_fn(g, source_node, fn_entry);
+        if (decl_node->data.fn_proto.generic_params.length > 0) {
+            TypeTableEntry *type_entry = decl_node->data.fn_proto.generic_fn_type;
+            assert(type_entry);
+            return resolve_expr_const_val_as_generic_fn(g, source_node, type_entry);
+        } else {
+            FnTableEntry *fn_entry = decl_node->data.fn_proto.fn_table_entry;
+            assert(fn_entry->type_entry);
+            return resolve_expr_const_val_as_fn(g, source_node, fn_entry);
+        }
     } else if (decl_node->type == NodeTypeStructDecl) {
         return resolve_expr_const_val_as_type(g, source_node, decl_node->data.struct_decl.type_entry);
     } else if (decl_node->type == NodeTypeTypeDecl) {
@@ -3113,7 +3167,7 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import,
 
 // Set name to nullptr to make the variable anonymous (not visible to programmer).
 static VariableTableEntry *add_local_var(CodeGen *g, AstNode *source_node, ImportTableEntry *import,
-        BlockContext *context, Buf *name, TypeTableEntry *type_entry, bool is_const)
+        BlockContext *context, Buf *name, TypeTableEntry *type_entry, bool is_const, AstNode *val_node)
 {
     VariableTableEntry *variable_entry = allocate<VariableTableEntry>(1);
     variable_entry->type = type_entry;
@@ -3160,6 +3214,8 @@ static VariableTableEntry *add_local_var(CodeGen *g, AstNode *source_node, Impor
     variable_entry->is_const = is_const;
     variable_entry->is_ptr = true;
     variable_entry->decl_node = source_node;
+    variable_entry->val_node = val_node;
+
 
     return variable_entry;
 }
@@ -3182,7 +3238,7 @@ static TypeTableEntry *analyze_unwrap_error_expr(CodeGen *g, ImportTableEntry *i
             var_node->block_context = child_context;
             Buf *var_name = &var_node->data.symbol_expr.symbol;
             node->data.unwrap_err_expr.var = add_local_var(g, var_node, import, child_context, var_name,
-                    g->builtin_types.entry_pure_error, true);
+                    g->builtin_types.entry_pure_error, true, nullptr);
         } else {
             child_context = parent_context;
         }
@@ -3260,7 +3316,8 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa
     assert(type != nullptr); // should have been caught by the parser
 
     VariableTableEntry *var = add_local_var(g, source_node, import, context,
-            &variable_declaration->symbol, type, is_const);
+            &variable_declaration->symbol, type, is_const,
+            expr_is_maybe ? nullptr : variable_declaration->expr);
 
     variable_declaration->variable = var;
 
@@ -3453,17 +3510,17 @@ static TypeTableEntry *analyze_for_expr(CodeGen *g, ImportTableEntry *import, Bl
     elem_var_node->block_context = child_context;
     Buf *elem_var_name = &elem_var_node->data.symbol_expr.symbol;
     node->data.for_expr.elem_var = add_local_var(g, elem_var_node, import, child_context, elem_var_name,
-            child_type, true);
+            child_type, true, nullptr);
 
     AstNode *index_var_node = node->data.for_expr.index_node;
     if (index_var_node) {
         Buf *index_var_name = &index_var_node->data.symbol_expr.symbol;
         index_var_node->block_context = child_context;
         node->data.for_expr.index_var = add_local_var(g, index_var_node, import, child_context, index_var_name,
-                g->builtin_types.entry_isize, true);
+                g->builtin_types.entry_isize, true, nullptr);
     } else {
         node->data.for_expr.index_var = add_local_var(g, node, import, child_context, nullptr,
-                g->builtin_types.entry_isize, true);
+                g->builtin_types.entry_isize, true, nullptr);
     }
 
     AstNode *for_body_node = node->data.for_expr.body;
@@ -4330,6 +4387,7 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
                     case TypeTableEntryIdNumLitInt:
                     case TypeTableEntryIdUndefLit:
                     case TypeTableEntryIdNamespace:
+                    case TypeTableEntryIdGenericFn:
                         add_node_error(g, expr_node,
                                 buf_sprintf("type '%s' not eligible for @typeof", buf_ptr(&type_entry->name)));
                         return g->builtin_types.entry_invalid;
@@ -4541,6 +4599,92 @@ static TypeTableEntry *analyze_fn_call_raw(CodeGen *g, ImportTableEntry *import,
     return analyze_fn_call_ptr(g, import, context, expected_type, node, fn_table_entry->type_entry, struct_type);
 }
 
+static TypeTableEntry *analyze_generic_fn_call(CodeGen *g, ImportTableEntry *import, BlockContext *parent_context,
+        TypeTableEntry *expected_type, AstNode *node, TypeTableEntry *generic_fn_type)
+{
+    assert(node->type == NodeTypeFnCallExpr);
+    assert(generic_fn_type->id == TypeTableEntryIdGenericFn);
+
+    AstNode *decl_node = generic_fn_type->data.generic_fn.decl_node;
+    assert(decl_node->type == NodeTypeFnProto);
+
+    int expected_param_count = decl_node->data.fn_proto.generic_params.length;
+    int actual_param_count = node->data.fn_call_expr.params.length;
+
+    if (actual_param_count != expected_param_count) {
+        add_node_error(g, first_executing_node(node),
+                buf_sprintf("expected %d arguments, got %d", expected_param_count, actual_param_count));
+        return g->builtin_types.entry_invalid;
+    }
+
+    GenericFnTypeId *generic_fn_type_id = allocate<GenericFnTypeId>(1);
+    generic_fn_type_id->decl_node = decl_node;
+    generic_fn_type_id->generic_param_count = actual_param_count;
+    generic_fn_type_id->generic_params = allocate<GenericParamValue>(actual_param_count);
+
+    BlockContext *child_context = import->block_context;
+    for (int i = 0; i < actual_param_count; i += 1) {
+        AstNode *generic_param_decl_node = decl_node->data.fn_proto.generic_params.at(i);
+        assert(generic_param_decl_node->type == NodeTypeParamDecl);
+
+        AstNode **generic_param_type_node = &generic_param_decl_node->data.param_decl.type;
+
+        TypeTableEntry *expected_param_type = analyze_expression(g, decl_node->owner,
+                decl_node->owner->block_context, nullptr, *generic_param_type_node);
+        if (expected_param_type->id == TypeTableEntryIdInvalid) {
+            return expected_param_type;
+        }
+        AstNode **param_node = &node->data.fn_call_expr.params.at(i);
+
+        TypeTableEntry *param_type = analyze_expression(g, import, child_context, expected_param_type,
+                *param_node);
+        if (param_type->id == TypeTableEntryIdInvalid) {
+            return param_type;
+        }
+
+        // set child_context so that the previous param is in scope
+        child_context = new_block_context(generic_param_decl_node, child_context);
+
+        ConstExprValue *const_val = &get_resolved_expr(*param_node)->const_val;
+        if (const_val->ok) {
+            add_local_var(g, generic_param_decl_node, decl_node->owner, child_context,
+                    &generic_param_decl_node->data.param_decl.name, param_type, true, *param_node);
+        } else {
+            add_node_error(g, *param_node, buf_sprintf("unable to resolve constant expression"));
+
+            add_local_var(g, generic_param_decl_node, decl_node->owner, child_context,
+                    &generic_param_decl_node->data.param_decl.name, g->builtin_types.entry_invalid,
+                    true, nullptr);
+
+            return g->builtin_types.entry_invalid;
+        }
+
+        GenericParamValue *generic_param_value = &generic_fn_type_id->generic_params[i];
+        generic_param_value->type = param_type;
+        generic_param_value->node = *param_node;
+    }
+
+
+    auto entry = g->generic_table.maybe_get(generic_fn_type_id);
+    if (entry) {
+        AstNode *impl_decl_node = entry->value;
+        assert(impl_decl_node->type == NodeTypeFnProto);
+        FnTableEntry *fn_table_entry = impl_decl_node->data.fn_proto.fn_table_entry;
+        return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+    }
+
+    // make a type from the generic parameters supplied
+    assert(decl_node->type == NodeTypeFnProto);
+    AstNode *impl_decl_node = ast_clone_subtree(decl_node);
+
+    preview_fn_proto_instance(g, import, decl_node, child_context);
+
+    g->generic_table.put(generic_fn_type_id, impl_decl_node);
+
+    FnTableEntry *fn_table_entry = decl_node->data.fn_proto.fn_table_entry;
+    return resolve_expr_const_val_as_fn(g, node, fn_table_entry);
+}
+
 static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -4627,6 +4771,8 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
 
             return analyze_fn_call_raw(g, import, context, expected_type, node,
                     const_val->data.x_fn, bare_struct_type);
+        } else if (invoke_type_entry->id == TypeTableEntryIdGenericFn) {
+            return analyze_generic_fn_call(g, import, context, expected_type, node, const_val->data.x_type);
         } else {
             add_node_error(g, fn_ref_expr,
                 buf_sprintf("type '%s' not a function", buf_ptr(&invoke_type_entry->name)));
@@ -4971,7 +5117,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
             Buf *var_name = &var_node->data.symbol_expr.symbol;
             var_node->block_context = child_context;
             prong_node->data.switch_prong.var = add_local_var(g, var_node, import,
-                    child_context, var_name, var_type, true);
+                    child_context, var_name, var_type, true, nullptr);
             prong_node->data.switch_prong.var_is_target_expr = var_is_target_expr;
         }
     }
@@ -5391,7 +5537,7 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) {
         }
 
         VariableTableEntry *var = add_local_var(g, param_decl_node, import, context, &param_decl->name,
-                type, true);
+                type, true, nullptr);
         var->src_arg_index = i;
         param_decl_node->data.param_decl.variable = var;
 
@@ -5413,7 +5559,9 @@ static void add_top_level_decl(CodeGen *g, ImportTableEntry *import, BlockContex
     tld->import = import;
     tld->name = name;
 
-    if (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport) {
+    bool want_as_export = (g->check_unused || g->is_test_build || tld->visib_mod == VisibModExport);
+    bool is_generic = (node->type == NodeTypeFnProto && node->data.fn_proto.generic_params.length > 0);
+    if (!is_generic && want_as_export) {
         g->export_queue.append(node);
     }
 
@@ -5909,6 +6057,7 @@ bool handle_is_ptr(TypeTableEntry *type_entry) {
         case TypeTableEntryIdNumLitInt:
         case TypeTableEntryIdUndefLit:
         case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
              zig_unreachable();
         case TypeTableEntryIdUnreachable:
         case TypeTableEntryIdVoid:
@@ -5965,7 +6114,6 @@ uint32_t fn_type_id_hash(FnTypeId *id) {
     result += id->is_cold ? 3605523458 : 0;
     result += id->is_var_args ? 1931444534 : 0;
     result += hash_ptr(id->return_type);
-    result += id->param_count;
     for (int i = 0; i < id->param_count; i += 1) {
         FnTypeParamInfo *info = &id->param_info[i];
         result += info->is_noalias ? 892356923 : 0;
@@ -5999,6 +6147,76 @@ bool fn_type_id_eql(FnTypeId *a, FnTypeId *b) {
     return true;
 }
 
+static uint32_t hash_const_val(TypeTableEntry *type, ConstExprValue *const_val) {
+    switch (type->id) {
+        case TypeTableEntryIdBool:
+            return const_val->data.x_bool ? 127863866 : 215080464;
+        case TypeTableEntryIdMetaType:
+            return hash_ptr(const_val->data.x_type);
+        case TypeTableEntryIdVoid:
+            return 4149439618;
+        case TypeTableEntryIdInt:
+        case TypeTableEntryIdNumLitInt:
+            return ((uint32_t)(bignum_to_twos_complement(&const_val->data.x_bignum) % UINT32_MAX)) * 1331471175;
+        case TypeTableEntryIdFloat:
+        case TypeTableEntryIdNumLitFloat:
+            return const_val->data.x_bignum.data.x_float * UINT32_MAX;
+        case TypeTableEntryIdPointer:
+            return hash_ptr(const_val->data.x_ptr.ptr);
+        case TypeTableEntryIdUndefLit:
+            return 162837799;
+        case TypeTableEntryIdArray:
+            // TODO better hashing algorithm
+            return 1166190605;
+        case TypeTableEntryIdStruct:
+            // TODO better hashing algorithm
+            return 1532530855;
+        case TypeTableEntryIdMaybe:
+            if (const_val->data.x_maybe) {
+                TypeTableEntry *child_type = type->data.maybe.child_type;
+                return hash_const_val(child_type, const_val->data.x_maybe) * 1992916303;
+            } else {
+                return 4016830364;
+            }
+        case TypeTableEntryIdErrorUnion:
+            // TODO better hashing algorithm
+            return 3415065496;
+        case TypeTableEntryIdPureError:
+            // TODO better hashing algorithm
+            return 2630160122;
+        case TypeTableEntryIdEnum:
+            // TODO better hashing algorithm
+            return 31643936;
+        case TypeTableEntryIdFn:
+            return hash_ptr(const_val->data.x_fn);
+        case TypeTableEntryIdTypeDecl:
+            return hash_ptr(const_val->data.x_type);
+        case TypeTableEntryIdNamespace:
+            return hash_ptr(const_val->data.x_import);
+        case TypeTableEntryIdGenericFn:
+        case TypeTableEntryIdInvalid:
+        case TypeTableEntryIdUnreachable:
+            zig_unreachable();
+    }
+}
+
+uint32_t generic_fn_type_id_hash(GenericFnTypeId *id) {
+    uint32_t result = 0;
+    result += hash_ptr(id->decl_node);
+    for (int i = 0; i < id->generic_param_count; i += 1) {
+        GenericParamValue *generic_param = &id->generic_params[i];
+        ConstExprValue *const_val = &get_resolved_expr(generic_param->node)->const_val;
+        assert(const_val->ok);
+        result += hash_const_val(generic_param->type, const_val);
+    }
+    return result;
+}
+
+bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
+    // TODO
+    return true;
+}
+
 bool type_has_bits(TypeTableEntry *type_entry) {
     assert(type_entry);
     assert(type_entry->id != TypeTableEntryIdInvalid);
@@ -6027,6 +6245,7 @@ static TypeTableEntry *type_of_first_thing_in_memory(TypeTableEntry *type_entry)
         case TypeTableEntryIdMetaType:
         case TypeTableEntryIdVoid:
         case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
             zig_unreachable();
         case TypeTableEntryIdArray:
             return type_of_first_thing_in_memory(type_entry->data.array.child_type);
src/ast_render.cpp
@@ -59,15 +59,6 @@ static const char *prefix_op_str(PrefixOp prefix_op) {
     zig_unreachable();
 }
 
-static const char *return_prefix_str(ReturnKind kind) {
-    switch (kind) {
-        case ReturnKindError: return "%";
-        case ReturnKindMaybe: return "?";
-        case ReturnKindUnconditional: return "";
-    }
-    zig_unreachable();
-}
-
 static const char *visib_mod_string(VisibMod mod) {
     switch (mod) {
         case VisibModPub: return "pub ";
@@ -195,316 +186,36 @@ static const char *node_type_str(NodeType node_type) {
     zig_unreachable();
 }
 
+struct AstPrint {
+    int indent;
+    FILE *f;
+};
 
-void ast_print(FILE *f, AstNode *node, int indent) {
-    for (int i = 0; i < indent; i += 1) {
-        fprintf(f, " ");
-    }
-    assert(node->type == NodeTypeRoot || *node->parent_field == node);
-
-    switch (node->type) {
-        case NodeTypeRoot:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
-                AstNode *child = node->data.root.top_level_decls.at(i);
-                ast_print(f, child, indent + 2);
-            }
-            break;
-        case NodeTypeFnDef:
-            {
-                fprintf(f, "%s\n", node_type_str(node->type));
-                AstNode *child = node->data.fn_def.fn_proto;
-                ast_print(f, child, indent + 2);
-                ast_print(f, node->data.fn_def.body, indent + 2);
-                break;
-            }
-        case NodeTypeFnProto:
-            {
-                Buf *name_buf = &node->data.fn_proto.name;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
+static void ast_print_visit(AstNode **node_ptr, void *context) {
+    AstNode *node = *node_ptr;
+    AstPrint *ap = (AstPrint *)context;
 
-                for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
-                    AstNode *child = node->data.fn_proto.params.at(i);
-                    ast_print(f, child, indent + 2);
-                }
+    for (int i = 0; i < ap->indent; i += 1) {
+        fprintf(ap->f, " ");
+    }
 
-                ast_print(f, node->data.fn_proto.return_type, indent + 2);
+    fprintf(ap->f, "%s\n", node_type_str(node->type));
 
-                break;
-            }
-        case NodeTypeBlock:
-            {
-                fprintf(f, "%s\n", node_type_str(node->type));
-                for (int i = 0; i < node->data.block.statements.length; i += 1) {
-                    AstNode *child = node->data.block.statements.at(i);
-                    ast_print(f, child, indent + 2);
-                }
-                break;
-            }
-        case NodeTypeParamDecl:
-            {
-                Buf *name_buf = &node->data.param_decl.name;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
+    AstPrint new_ap;
+    new_ap.indent = ap->indent + 2;
+    new_ap.f = ap->f;
 
-                ast_print(f, node->data.param_decl.type, indent + 2);
+    ast_visit_node_children(node, ast_print_visit, &new_ap);
+}
 
-                break;
-            }
-        case NodeTypeReturnExpr:
-            {
-                const char *prefix_str = return_prefix_str(node->data.return_expr.kind);
-                fprintf(f, "%s%s\n", prefix_str, node_type_str(node->type));
-                if (node->data.return_expr.expr)
-                    ast_print(f, node->data.return_expr.expr, indent + 2);
-                break;
-            }
-        case NodeTypeDefer:
-            {
-                const char *prefix_str = return_prefix_str(node->data.defer.kind);
-                fprintf(f, "%s%s\n", prefix_str, node_type_str(node->type));
-                if (node->data.defer.expr)
-                    ast_print(f, node->data.defer.expr, indent + 2);
-                break;
-            }
-        case NodeTypeVariableDeclaration:
-            {
-                Buf *name_buf = &node->data.variable_declaration.symbol;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
-                if (node->data.variable_declaration.type)
-                    ast_print(f, node->data.variable_declaration.type, indent + 2);
-                if (node->data.variable_declaration.expr)
-                    ast_print(f, node->data.variable_declaration.expr, indent + 2);
-                break;
-            }
-        case NodeTypeTypeDecl:
-            {
-                Buf *name_buf = &node->data.type_decl.symbol;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
-                ast_print(f, node->data.type_decl.child_type, indent + 2);
-                break;
-            }
-        case NodeTypeErrorValueDecl:
-            {
-                Buf *name_buf = &node->data.error_value_decl.name;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
-                break;
-            }
-        case NodeTypeFnDecl:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.fn_decl.fn_proto, indent + 2);
-            break;
-        case NodeTypeBinOpExpr:
-            fprintf(f, "%s %s\n", node_type_str(node->type),
-                    bin_op_str(node->data.bin_op_expr.bin_op));
-            ast_print(f, node->data.bin_op_expr.op1, indent + 2);
-            ast_print(f, node->data.bin_op_expr.op2, indent + 2);
-            break;
-        case NodeTypeUnwrapErrorExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.unwrap_err_expr.op1, indent + 2);
-            if (node->data.unwrap_err_expr.symbol) {
-                ast_print(f, node->data.unwrap_err_expr.symbol, indent + 2);
-            }
-            ast_print(f, node->data.unwrap_err_expr.op2, indent + 2);
-            break;
-        case NodeTypeFnCallExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.fn_call_expr.fn_ref_expr, indent + 2);
-            for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
-                AstNode *child = node->data.fn_call_expr.params.at(i);
-                ast_print(f, child, indent + 2);
-            }
-            break;
-        case NodeTypeArrayAccessExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.array_access_expr.array_ref_expr, indent + 2);
-            ast_print(f, node->data.array_access_expr.subscript, indent + 2);
-            break;
-        case NodeTypeSliceExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.slice_expr.array_ref_expr, indent + 2);
-            ast_print(f, node->data.slice_expr.start, indent + 2);
-            if (node->data.slice_expr.end) {
-                ast_print(f, node->data.slice_expr.end, indent + 2);
-            }
-            break;
-        case NodeTypeDirective:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.directive.expr, indent + 2);
-            break;
-        case NodeTypePrefixOpExpr:
-            fprintf(f, "%s %s\n", node_type_str(node->type),
-                    prefix_op_str(node->data.prefix_op_expr.prefix_op));
-            ast_print(f, node->data.prefix_op_expr.primary_expr, indent + 2);
-            break;
-        case NodeTypeNumberLiteral:
-            {
-                NumLit kind = node->data.number_literal.kind;
-                const char *name = node_type_str(node->type);
-                if (kind == NumLitUInt) {
-                    fprintf(f, "%s uint %" PRIu64 "\n", name, node->data.number_literal.data.x_uint);
-                } else {
-                    fprintf(f, "%s float %f\n", name, node->data.number_literal.data.x_float);
-                }
-                break;
-            }
-        case NodeTypeStringLiteral:
-            {
-                const char *c = node->data.string_literal.c ? "c" : "";
-                fprintf(f, "StringLiteral %s'%s'\n", c,
-                        buf_ptr(&node->data.string_literal.buf));
-                break;
-            }
-        case NodeTypeCharLiteral:
-            {
-                fprintf(f, "%s '%c'\n", node_type_str(node->type), node->data.char_literal.value);
-                break;
-            }
-        case NodeTypeSymbol:
-            fprintf(f, "Symbol %s\n", buf_ptr(&node->data.symbol_expr.symbol));
-            break;
-        case NodeTypeUse:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.use.expr, indent + 2);
-            break;
-        case NodeTypeBoolLiteral:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type),
-                    node->data.bool_literal.value ? "true" : "false");
-            break;
-        case NodeTypeNullLiteral:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeIfBoolExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            if (node->data.if_bool_expr.condition)
-                ast_print(f, node->data.if_bool_expr.condition, indent + 2);
-            ast_print(f, node->data.if_bool_expr.then_block, indent + 2);
-            if (node->data.if_bool_expr.else_node)
-                ast_print(f, node->data.if_bool_expr.else_node, indent + 2);
-            break;
-        case NodeTypeIfVarExpr:
-            {
-                Buf *name_buf = &node->data.if_var_expr.var_decl.symbol;
-                fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
-                if (node->data.if_var_expr.var_decl.type)
-                    ast_print(f, node->data.if_var_expr.var_decl.type, indent + 2);
-                if (node->data.if_var_expr.var_decl.expr)
-                    ast_print(f, node->data.if_var_expr.var_decl.expr, indent + 2);
-                ast_print(f, node->data.if_var_expr.then_block, indent + 2);
-                if (node->data.if_var_expr.else_node)
-                    ast_print(f, node->data.if_var_expr.else_node, indent + 2);
-                break;
-            }
-        case NodeTypeWhileExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.while_expr.condition, indent + 2);
-            ast_print(f, node->data.while_expr.body, indent + 2);
-            break;
-        case NodeTypeForExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.for_expr.elem_node, indent + 2);
-            ast_print(f, node->data.for_expr.array_expr, indent + 2);
-            if (node->data.for_expr.index_node) {
-                ast_print(f, node->data.for_expr.index_node, indent + 2);
-            }
-            ast_print(f, node->data.for_expr.body, indent + 2);
-            break;
-        case NodeTypeSwitchExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.switch_expr.expr, indent + 2);
-            for (int i = 0; i < node->data.switch_expr.prongs.length; i += 1) {
-                AstNode *child_node = node->data.switch_expr.prongs.at(i);
-                ast_print(f, child_node, indent + 2);
-            }
-            break;
-        case NodeTypeSwitchProng:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            for (int i = 0; i < node->data.switch_prong.items.length; i += 1) {
-                AstNode *child_node = node->data.switch_prong.items.at(i);
-                ast_print(f, child_node, indent + 2);
-            }
-            if (node->data.switch_prong.var_symbol) {
-                ast_print(f, node->data.switch_prong.var_symbol, indent + 2);
-            }
-            ast_print(f, node->data.switch_prong.expr, indent + 2);
-            break;
-        case NodeTypeSwitchRange:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.switch_range.start, indent + 2);
-            ast_print(f, node->data.switch_range.end, indent + 2);
-            break;
-        case NodeTypeLabel:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.label.name));
-            break;
-        case NodeTypeGoto:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.goto_expr.name));
-            break;
-        case NodeTypeBreak:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeContinue:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeUndefinedLiteral:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeAsmExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeFieldAccessExpr:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type),
-                    buf_ptr(&node->data.field_access_expr.field_name));
-            ast_print(f, node->data.field_access_expr.struct_expr, indent + 2);
-            break;
-        case NodeTypeStructDecl:
-            fprintf(f, "%s '%s'\n",
-                    node_type_str(node->type), buf_ptr(&node->data.struct_decl.name));
-            for (int i = 0; i < node->data.struct_decl.fields.length; i += 1) {
-                AstNode *child = node->data.struct_decl.fields.at(i);
-                ast_print(f, child, indent + 2);
-            }
-            for (int i = 0; i < node->data.struct_decl.fns.length; i += 1) {
-                AstNode *child = node->data.struct_decl.fns.at(i);
-                ast_print(f, child, indent + 2);
-            }
-            break;
-        case NodeTypeStructField:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_field.name));
-            if (node->data.struct_field.type) {
-                ast_print(f, node->data.struct_field.type, indent + 2);
-            }
-            break;
-        case NodeTypeStructValueField:
-            fprintf(f, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.struct_val_field.name));
-            ast_print(f, node->data.struct_val_field.expr, indent + 2);
-            break;
-        case NodeTypeContainerInitExpr:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            ast_print(f, node->data.container_init_expr.type, indent + 2);
-            for (int i = 0; i < node->data.container_init_expr.entries.length; i += 1) {
-                AstNode *child = node->data.container_init_expr.entries.at(i);
-                ast_print(f, child, indent + 2);
-            }
-            break;
-        case NodeTypeArrayType:
-            {
-                const char *const_str = node->data.array_type.is_const ? "const" : "var";
-                fprintf(f, "%s %s\n", node_type_str(node->type), const_str);
-                if (node->data.array_type.size) {
-                    ast_print(f, node->data.array_type.size, indent + 2);
-                }
-                ast_print(f, node->data.array_type.child_type, indent + 2);
-                break;
-            }
-        case NodeTypeErrorType:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-        case NodeTypeTypeLiteral:
-            fprintf(f, "%s\n", node_type_str(node->type));
-            break;
-    }
+void ast_print(FILE *f, AstNode *node, int indent) {
+    AstPrint ap;
+    ap.indent = indent;
+    ap.f = f;
+    ast_visit_node_children(node, ast_print_visit, &ap);
 }
 
+
 struct AstRender {
     int indent;
     int indent_size;
src/ast_render.hpp
@@ -9,6 +9,7 @@
 #define ZIG_AST_RENDER_HPP
 
 #include "all_types.hpp"
+#include "parser.hpp"
 
 #include <stdio.h>
 
src/codegen.cpp
@@ -62,6 +62,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) {
     g->primitive_type_table.init(32);
     g->fn_type_table.init(32);
     g->error_table.init(16);
+    g->generic_table.init(16);
     g->is_release_build = false;
     g->is_test_build = false;
     g->error_value_count = 1;
@@ -2927,6 +2928,7 @@ static LLVMValueRef gen_const_val(CodeGen *g, TypeTableEntry *type_entry, ConstE
         case TypeTableEntryIdUndefLit:
         case TypeTableEntryIdVoid:
         case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
             zig_unreachable();
 
     }
src/parser.cpp
@@ -2243,7 +2243,7 @@ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandato
 }
 
 /*
-FnProto : "fn" option("Symbol") ParamDeclList option("->" PrefixOpExpression)
+FnProto = "fn" option("Symbol") option(ParamDeclList) ParamDeclList option("->" TypeExpr)
 */
 static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mandatory,
         ZigList<AstNode*> *directives, VisibMod visib_mod)
@@ -2273,6 +2273,17 @@ static AstNode *ast_parse_fn_proto(ParseContext *pc, int *token_index, bool mand
 
     ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args);
 
+    Token *maybe_lparen = &pc->tokens->at(*token_index);
+    if (maybe_lparen->id == TokenIdLParen) {
+        for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
+            node->data.fn_proto.generic_params.append(node->data.fn_proto.params.at(i));
+        }
+        node->data.fn_proto.generic_params_is_var_args = node->data.fn_proto.is_var_args;
+
+        node->data.fn_proto.params.resize(0);
+        ast_parse_param_decl_list(pc, token_index, &node->data.fn_proto.params, &node->data.fn_proto.is_var_args);
+    }
+
     Token *next_token = &pc->tokens->at(*token_index);
     if (next_token->id == TokenIdArrow) {
         *token_index += 1;
@@ -2626,72 +2637,73 @@ AstNode *ast_parse(Buf *buf, ZigList<Token> *tokens, ImportTableEntry *owner,
     return pc.root;
 }
 
-static void set_field(AstNode **field) {
-    if (*field) {
-        (*field)->parent_field = field;
+static void visit_field(AstNode **node, void (*visit)(AstNode **, void *context), void *context) {
+    if (*node) {
+        visit(node, context);
     }
 }
 
-static void set_list_fields(ZigList<AstNode*> *list) {
+static void visit_node_list(ZigList<AstNode *> *list, void (*visit)(AstNode **, void *context), void *context) {
     if (list) {
         for (int i = 0; i < list->length; i += 1) {
-            set_field(&list->at(i));
+            visit(&list->at(i), context);
         }
     }
 }
 
-void normalize_parent_ptrs(AstNode *node) {
+void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *context), void *context) {
     switch (node->type) {
         case NodeTypeRoot:
-            set_list_fields(&node->data.root.top_level_decls);
+            visit_node_list(&node->data.root.top_level_decls, visit, context);
             break;
         case NodeTypeFnProto:
-            set_field(&node->data.fn_proto.return_type);
-            set_list_fields(node->data.fn_proto.top_level_decl.directives);
-            set_list_fields(&node->data.fn_proto.params);
+            visit_field(&node->data.fn_proto.return_type, visit, context);
+            visit_node_list(node->data.fn_proto.top_level_decl.directives, visit, context);
+            visit_node_list(&node->data.fn_proto.generic_params, visit, context);
+            visit_node_list(&node->data.fn_proto.params, visit, context);
             break;
         case NodeTypeFnDef:
-            set_field(&node->data.fn_def.fn_proto);
-            set_field(&node->data.fn_def.body);
+            visit_field(&node->data.fn_def.fn_proto, visit, context);
+            visit_field(&node->data.fn_def.body, visit, context);
             break;
         case NodeTypeFnDecl:
-            set_field(&node->data.fn_decl.fn_proto);
+            visit_field(&node->data.fn_decl.fn_proto, visit, context);
             break;
         case NodeTypeParamDecl:
-            set_field(&node->data.param_decl.type);
+            visit_field(&node->data.param_decl.type, visit, context);
             break;
         case NodeTypeBlock:
-            set_list_fields(&node->data.block.statements);
+            visit_node_list(&node->data.block.statements, visit, context);
             break;
         case NodeTypeDirective:
-            set_field(&node->data.directive.expr);
+            visit_field(&node->data.directive.expr, visit, context);
             break;
         case NodeTypeReturnExpr:
-            set_field(&node->data.return_expr.expr);
+            visit_field(&node->data.return_expr.expr, visit, context);
             break;
         case NodeTypeDefer:
-            set_field(&node->data.defer.expr);
+            visit_field(&node->data.defer.expr, visit, context);
             break;
         case NodeTypeVariableDeclaration:
-            set_list_fields(node->data.variable_declaration.top_level_decl.directives);
-            set_field(&node->data.variable_declaration.type);
-            set_field(&node->data.variable_declaration.expr);
+            visit_node_list(node->data.variable_declaration.top_level_decl.directives, visit, context);
+            visit_field(&node->data.variable_declaration.type, visit, context);
+            visit_field(&node->data.variable_declaration.expr, visit, context);
             break;
         case NodeTypeTypeDecl:
-            set_list_fields(node->data.type_decl.top_level_decl.directives);
-            set_field(&node->data.type_decl.child_type);
+            visit_node_list(node->data.type_decl.top_level_decl.directives, visit, context);
+            visit_field(&node->data.type_decl.child_type, visit, context);
             break;
         case NodeTypeErrorValueDecl:
             // none
             break;
         case NodeTypeBinOpExpr:
-            set_field(&node->data.bin_op_expr.op1);
-            set_field(&node->data.bin_op_expr.op2);
+            visit_field(&node->data.bin_op_expr.op1, visit, context);
+            visit_field(&node->data.bin_op_expr.op2, visit, context);
             break;
         case NodeTypeUnwrapErrorExpr:
-            set_field(&node->data.unwrap_err_expr.op1);
-            set_field(&node->data.unwrap_err_expr.symbol);
-            set_field(&node->data.unwrap_err_expr.op2);
+            visit_field(&node->data.unwrap_err_expr.op1, visit, context);
+            visit_field(&node->data.unwrap_err_expr.symbol, visit, context);
+            visit_field(&node->data.unwrap_err_expr.op2, visit, context);
             break;
         case NodeTypeNumberLiteral:
             // none
@@ -2706,27 +2718,27 @@ void normalize_parent_ptrs(AstNode *node) {
             // none
             break;
         case NodeTypePrefixOpExpr:
-            set_field(&node->data.prefix_op_expr.primary_expr);
+            visit_field(&node->data.prefix_op_expr.primary_expr, visit, context);
             break;
         case NodeTypeFnCallExpr:
-            set_field(&node->data.fn_call_expr.fn_ref_expr);
-            set_list_fields(&node->data.fn_call_expr.params);
+            visit_field(&node->data.fn_call_expr.fn_ref_expr, visit, context);
+            visit_node_list(&node->data.fn_call_expr.params, visit, context);
             break;
         case NodeTypeArrayAccessExpr:
-            set_field(&node->data.array_access_expr.array_ref_expr);
-            set_field(&node->data.array_access_expr.subscript);
+            visit_field(&node->data.array_access_expr.array_ref_expr, visit, context);
+            visit_field(&node->data.array_access_expr.subscript, visit, context);
             break;
         case NodeTypeSliceExpr:
-            set_field(&node->data.slice_expr.array_ref_expr);
-            set_field(&node->data.slice_expr.start);
-            set_field(&node->data.slice_expr.end);
+            visit_field(&node->data.slice_expr.array_ref_expr, visit, context);
+            visit_field(&node->data.slice_expr.start, visit, context);
+            visit_field(&node->data.slice_expr.end, visit, context);
             break;
         case NodeTypeFieldAccessExpr:
-            set_field(&node->data.field_access_expr.struct_expr);
+            visit_field(&node->data.field_access_expr.struct_expr, visit, context);
             break;
         case NodeTypeUse:
-            set_field(&node->data.use.expr);
-            set_list_fields(node->data.use.top_level_decl.directives);
+            visit_field(&node->data.use.expr, visit, context);
+            visit_node_list(node->data.use.top_level_decl.directives, visit, context);
             break;
         case NodeTypeBoolLiteral:
             // none
@@ -2738,38 +2750,38 @@ void normalize_parent_ptrs(AstNode *node) {
             // none
             break;
         case NodeTypeIfBoolExpr:
-            set_field(&node->data.if_bool_expr.condition);
-            set_field(&node->data.if_bool_expr.then_block);
-            set_field(&node->data.if_bool_expr.else_node);
+            visit_field(&node->data.if_bool_expr.condition, visit, context);
+            visit_field(&node->data.if_bool_expr.then_block, visit, context);
+            visit_field(&node->data.if_bool_expr.else_node, visit, context);
             break;
         case NodeTypeIfVarExpr:
-            set_field(&node->data.if_var_expr.var_decl.type);
-            set_field(&node->data.if_var_expr.var_decl.expr);
-            set_field(&node->data.if_var_expr.then_block);
-            set_field(&node->data.if_var_expr.else_node);
+            visit_field(&node->data.if_var_expr.var_decl.type, visit, context);
+            visit_field(&node->data.if_var_expr.var_decl.expr, visit, context);
+            visit_field(&node->data.if_var_expr.then_block, visit, context);
+            visit_field(&node->data.if_var_expr.else_node, visit, context);
             break;
         case NodeTypeWhileExpr:
-            set_field(&node->data.while_expr.condition);
-            set_field(&node->data.while_expr.body);
+            visit_field(&node->data.while_expr.condition, visit, context);
+            visit_field(&node->data.while_expr.body, visit, context);
             break;
         case NodeTypeForExpr:
-            set_field(&node->data.for_expr.elem_node);
-            set_field(&node->data.for_expr.array_expr);
-            set_field(&node->data.for_expr.index_node);
-            set_field(&node->data.for_expr.body);
+            visit_field(&node->data.for_expr.elem_node, visit, context);
+            visit_field(&node->data.for_expr.array_expr, visit, context);
+            visit_field(&node->data.for_expr.index_node, visit, context);
+            visit_field(&node->data.for_expr.body, visit, context);
             break;
         case NodeTypeSwitchExpr:
-            set_field(&node->data.switch_expr.expr);
-            set_list_fields(&node->data.switch_expr.prongs);
+            visit_field(&node->data.switch_expr.expr, visit, context);
+            visit_node_list(&node->data.switch_expr.prongs, visit, context);
             break;
         case NodeTypeSwitchProng:
-            set_list_fields(&node->data.switch_prong.items);
-            set_field(&node->data.switch_prong.var_symbol);
-            set_field(&node->data.switch_prong.expr);
+            visit_node_list(&node->data.switch_prong.items, visit, context);
+            visit_field(&node->data.switch_prong.var_symbol, visit, context);
+            visit_field(&node->data.switch_prong.expr, visit, context);
             break;
         case NodeTypeSwitchRange:
-            set_field(&node->data.switch_range.start);
-            set_field(&node->data.switch_range.end);
+            visit_field(&node->data.switch_range.start, visit, context);
+            visit_field(&node->data.switch_range.end, visit, context);
             break;
         case NodeTypeLabel:
             // none
@@ -2786,32 +2798,32 @@ void normalize_parent_ptrs(AstNode *node) {
         case NodeTypeAsmExpr:
             for (int i = 0; i < node->data.asm_expr.input_list.length; i += 1) {
                 AsmInput *asm_input = node->data.asm_expr.input_list.at(i);
-                set_field(&asm_input->expr);
+                visit_field(&asm_input->expr, visit, context);
             }
             for (int i = 0; i < node->data.asm_expr.output_list.length; i += 1) {
                 AsmOutput *asm_output = node->data.asm_expr.output_list.at(i);
-                set_field(&asm_output->return_type);
+                visit_field(&asm_output->return_type, visit, context);
             }
             break;
         case NodeTypeStructDecl:
-            set_list_fields(&node->data.struct_decl.fields);
-            set_list_fields(&node->data.struct_decl.fns);
-            set_list_fields(node->data.struct_decl.top_level_decl.directives);
+            visit_node_list(&node->data.struct_decl.fields, visit, context);
+            visit_node_list(&node->data.struct_decl.fns, visit, context);
+            visit_node_list(node->data.struct_decl.top_level_decl.directives, visit, context);
             break;
         case NodeTypeStructField:
-            set_field(&node->data.struct_field.type);
-            set_list_fields(node->data.struct_field.top_level_decl.directives);
+            visit_field(&node->data.struct_field.type, visit, context);
+            visit_node_list(node->data.struct_field.top_level_decl.directives, visit, context);
             break;
         case NodeTypeContainerInitExpr:
-            set_field(&node->data.container_init_expr.type);
-            set_list_fields(&node->data.container_init_expr.entries);
+            visit_field(&node->data.container_init_expr.type, visit, context);
+            visit_node_list(&node->data.container_init_expr.entries, visit, context);
             break;
         case NodeTypeStructValueField:
-            set_field(&node->data.struct_val_field.expr);
+            visit_field(&node->data.struct_val_field.expr, visit, context);
             break;
         case NodeTypeArrayType:
-            set_field(&node->data.array_type.size);
-            set_field(&node->data.array_type.child_type);
+            visit_field(&node->data.array_type.size, visit, context);
+            visit_field(&node->data.array_type.child_type, visit, context);
             break;
         case NodeTypeErrorType:
             // none
@@ -2821,3 +2833,29 @@ void normalize_parent_ptrs(AstNode *node) {
             break;
     }
 }
+
+static void normalize_parent_ptrs_visit(AstNode **node, void *context) {
+    (*node)->parent_field = node;
+}
+
+void normalize_parent_ptrs(AstNode *node) {
+    ast_visit_node_children(node, normalize_parent_ptrs_visit, nullptr);
+}
+
+static AstNode *clone_node(AstNode *old_node) {
+    AstNode *new_node = allocate_nonzero<AstNode>(1);
+    memcpy(new_node, old_node, sizeof(AstNode));
+    return new_node;
+}
+
+static void ast_clone_subtree_visit(AstNode **node, void *context) {
+    *node = clone_node(*node);
+    (*node)->parent_field = node;
+    ast_visit_node_children(*node, ast_clone_subtree_visit, nullptr);
+}
+
+AstNode *ast_clone_subtree(AstNode *old_node) {
+    AstNode *new_node = clone_node(old_node);
+    ast_visit_node_children(new_node, ast_clone_subtree_visit, nullptr);
+    return new_node;
+}
src/parser.hpp
@@ -20,10 +20,11 @@ void ast_token_error(Token *token, const char *format, ...);
 AstNode * ast_parse(Buf *buf, ZigList<Token> *tokens, ImportTableEntry *owner, ErrColor err_color,
         uint32_t *next_node_index);
 
-const char *node_type_str(NodeType node_type);
-
 void ast_print(AstNode *node, int indent);
 
 void normalize_parent_ptrs(AstNode *node);
 
+AstNode *ast_clone_subtree(AstNode *node);
+void ast_visit_node_children(AstNode *node, void (*visit)(AstNode **, void *context), void *context);
+
 #endif
test/self_hosted.zig
@@ -512,6 +512,16 @@ three)";
 
 
 
+#attribute("test")
+fn simple_generic_fn() {
+    assert(max(i32)(3, -1) == 3);
+}
+
+fn max(T: type)(a: T, b: T) -> T {
+    return if (a > b) a else b;
+}
+
+
 fn assert(b: bool) {
     if (!b) unreachable{}
 }