Commit 893e152dab

Andrew Kelley <superjoe30@gmail.com>
2015-11-26 02:17:35
no errors during codegen
also, fix function calling and allow forward declarations
1 parent 6f460de
Changed files (1)
src/codegen.cpp
@@ -21,6 +21,9 @@
 struct FnTableEntry {
     LLVMValueRef fn_value;
     AstNode *proto_node;
+    AstNode *fn_def_node;
+    bool is_extern;
+    bool internal_linkage;
 };
 
 enum TypeId {
@@ -48,7 +51,6 @@ struct TypeTableEntry {
 struct CodeGen {
     LLVMModuleRef mod;
     AstNode *root;
-    HashMap<Buf *, AstNode *, buf_hash, buf_eql_buf> fn_defs;
     ZigList<ErrorMsg> errors;
     LLVMBuilderRef builder;
     llvm::DIBuilder *dbuilder;
@@ -68,6 +70,7 @@ struct CodeGen {
     Buf in_dir;
     ZigList<llvm::DIScope *> block_scopes;
     llvm::DIFile *di_file;
+    ZigList<FnTableEntry *> fn_defs;
 };
 
 struct TypeNode {
@@ -83,7 +86,6 @@ struct CodeGenNode {
 CodeGen *create_codegen(AstNode *root, Buf *in_full_path) {
     CodeGen *g = allocate<CodeGen>(1);
     g->root = root;
-    g->fn_defs.init(32);
     g->fn_table.init(32);
     g->str_table.init(32);
     g->type_table.init(32);
@@ -137,11 +139,12 @@ static llvm::DIType *to_llvm_debug_type(AstNode *type_node) {
 
 static bool type_is_unreachable(AstNode *type_node) {
     assert(type_node->type == NodeTypeType);
-    return type_node->data.type.type == AstNodeTypeTypePrimitive &&
-            buf_eql_str(&type_node->data.type.primitive_name, "unreachable");
+    assert(type_node->codegen_node);
+    assert(type_node->codegen_node->data.type_node.entry);
+    return type_node->codegen_node->data.type_node.entry->id == TypeIdUnreachable;
 }
 
-static void analyze_node(CodeGen *g, AstNode *node);
+static void find_declarations(CodeGen *g, AstNode *node);
 
 static void resolve_type_and_recurse(CodeGen *g, AstNode *node) {
     assert(!node->codegen_node);
@@ -163,7 +166,7 @@ static void resolve_type_and_recurse(CodeGen *g, AstNode *node) {
             }
         case AstNodeTypeTypePointer:
             {
-                analyze_node(g, node->data.type.child_type);
+                find_declarations(g, node->data.type.child_type);
                 TypeNode *child_type_node = &node->data.type.child_type->codegen_node->data.type_node;
                 if (child_type_node->entry->id == TypeIdUnreachable) {
                     add_node_error(g, node,
@@ -192,14 +195,8 @@ static void resolve_type_and_recurse(CodeGen *g, AstNode *node) {
     }
 }
 
-static void analyze_node(CodeGen *g, AstNode *node) {
+static void find_declarations(CodeGen *g, AstNode *node) {
     switch (node->type) {
-        case NodeTypeRoot:
-            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
-                AstNode *child = node->data.root.top_level_decls.at(i);
-                analyze_node(g, child);
-            }
-            break;
         case NodeTypeExternBlock:
             for (int i = 0; i < node->data.extern_block.directives->length; i += 1) {
                 AstNode *directive_node = node->data.extern_block.directives->at(i);
@@ -215,34 +212,14 @@ static void analyze_node(CodeGen *g, AstNode *node) {
 
             for (int fn_decl_i = 0; fn_decl_i < node->data.extern_block.fn_decls.length; fn_decl_i += 1) {
                 AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
-                analyze_node(g, fn_decl);
-
+                assert(fn_decl->type == NodeTypeFnDecl);
                 AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
+                find_declarations(g, fn_proto);
                 Buf *name = &fn_proto->data.fn_proto.name;
-                ZigList<AstNode *> *params = &fn_proto->data.fn_proto.params;
-
-                LLVMTypeRef *fn_param_values = allocate<LLVMTypeRef>(params->length);
-                for (int param_i = 0; param_i < params->length; param_i += 1) {
-                    AstNode *param_node = params->at(param_i);
-                    assert(param_node->type == NodeTypeParamDecl);
-                    AstNode *param_type = param_node->data.param_decl.type;
-                    fn_param_values[param_i] = to_llvm_type(param_type);
-                }
-                AstNode *return_type_node = fn_proto->data.fn_proto.return_type;
-                LLVMTypeRef return_type = to_llvm_type(return_type_node);
-
-                LLVMTypeRef fn_type = LLVMFunctionType(return_type, fn_param_values, params->length, 0);
-                LLVMValueRef fn_val = LLVMAddFunction(g->mod, buf_ptr(name), fn_type);
-                LLVMSetLinkage(fn_val, LLVMExternalLinkage);
-                LLVMSetFunctionCallConv(fn_val, LLVMCCallConv);
-
-                if (type_is_unreachable(return_type_node)) {
-                    LLVMAddFunctionAttr(fn_val, LLVMNoReturnAttribute);
-                }
 
                 FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
-                fn_table_entry->fn_value = fn_val;
                 fn_table_entry->proto_node = fn_proto;
+                fn_table_entry->is_extern = true;
                 g->fn_table.put(name, fn_table_entry);
             }
             break;
@@ -251,14 +228,74 @@ static void analyze_node(CodeGen *g, AstNode *node) {
                 AstNode *proto_node = node->data.fn_def.fn_proto;
                 assert(proto_node->type == NodeTypeFnProto);
                 Buf *proto_name = &proto_node->data.fn_proto.name;
-                auto entry = g->fn_defs.maybe_get(proto_name);
+                auto entry = g->fn_table.maybe_get(proto_name);
                 if (entry) {
                     add_node_error(g, node,
                             buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
                 } else {
-                    g->fn_defs.put(proto_name, node);
-                    analyze_node(g, proto_node);
+                    FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
+                    fn_table_entry->proto_node = proto_node;
+                    fn_table_entry->fn_def_node = node;
+                    g->fn_table.put(proto_name, fn_table_entry);
+                    g->fn_defs.append(fn_table_entry);
+
+                    find_declarations(g, proto_node);
+                }
+                break;
+            }
+        case NodeTypeFnProto:
+            {
+                for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
+                    AstNode *child = node->data.fn_proto.params.at(i);
+                    find_declarations(g, child);
                 }
+                find_declarations(g, node->data.fn_proto.return_type);
+                break;
+            }
+            break;
+        case NodeTypeParamDecl:
+            find_declarations(g, node->data.param_decl.type);
+            break;
+        case NodeTypeType:
+            resolve_type_and_recurse(g, node);
+            break;
+        case NodeTypeDirective:
+            // we handled directives in the parent function
+            break;
+        case NodeTypeFnDecl:
+        case NodeTypeStatementReturn:
+        case NodeTypeRoot:
+        case NodeTypeBlock:
+        case NodeTypeExpression:
+        case NodeTypeFnCall:
+            zig_unreachable();
+    }
+}
+
+static void analyze_node(CodeGen *g, AstNode *node) {
+    switch (node->type) {
+        case NodeTypeRoot:
+            // Iterate once over the top level declarations to build the function table
+            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
+                AstNode *child = node->data.root.top_level_decls.at(i);
+                find_declarations(g, child);
+            }
+            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
+                AstNode *child = node->data.root.top_level_decls.at(i);
+                analyze_node(g, child);
+            }
+            break;
+        case NodeTypeExternBlock:
+            for (int fn_decl_i = 0; fn_decl_i < node->data.extern_block.fn_decls.length; fn_decl_i += 1) {
+                AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
+                analyze_node(g, fn_decl);
+            }
+            break;
+        case NodeTypeFnDef:
+            {
+                AstNode *proto_node = node->data.fn_def.fn_proto;
+                assert(proto_node->type == NodeTypeFnProto);
+                analyze_node(g, proto_node);
                 break;
             }
         case NodeTypeFnDecl:
@@ -282,10 +319,8 @@ static void analyze_node(CodeGen *g, AstNode *node) {
             break;
 
         case NodeTypeType:
-            {
-                resolve_type_and_recurse(g, node);
-                break;
-            }
+            // ignore; we handled types with find_declarations
+            break;
         case NodeTypeBlock:
             for (int i = 0; i < node->data.block.statements.length; i += 1) {
                 AstNode *child = node->data.block.statements.at(i);
@@ -309,12 +344,33 @@ static void analyze_node(CodeGen *g, AstNode *node) {
             }
             break;
         case NodeTypeFnCall:
-            for (int i = 0; i < node->data.fn_call.params.length; i += 1) {
-                AstNode *child = node->data.fn_call.params.at(i);
-                analyze_node(g, child);
+            {
+                Buf *name = &node->data.fn_call.name;
+
+                auto entry = g->fn_table.maybe_get(name);
+                if (!entry) {
+                    add_node_error(g, node,
+                            buf_sprintf("undefined function: '%s'", buf_ptr(name)));
+                } else {
+                    FnTableEntry *fn_table_entry = entry->value;
+                    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
+                    int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
+                    int actual_param_count = node->data.fn_call.params.length;
+                    if (expected_param_count != actual_param_count) {
+                        add_node_error(g, node,
+                                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+                                    expected_param_count, actual_param_count));
+                    }
+                }
+
+                for (int i = 0; i < node->data.fn_call.params.length; i += 1) {
+                    AstNode *child = node->data.fn_call.params.at(i);
+                    analyze_node(g, child);
+                }
+                break;
             }
-            break;
         case NodeTypeDirective:
+            // we looked at directives in the parent node
             break;
     }
 }
@@ -399,7 +455,6 @@ void semantic_analyze(CodeGen *g) {
 
     add_types(g);
 
-    // Pass 1.
     analyze_node(g, g->root);
 }
 
@@ -415,23 +470,11 @@ static LLVMValueRef gen_fn_call(CodeGen *g, AstNode *fn_call_node) {
     assert(fn_call_node->type == NodeTypeFnCall);
 
     Buf *name = &fn_call_node->data.fn_call.name;
-
-    auto entry = g->fn_table.maybe_get(name);
-    if (!entry) {
-        add_node_error(g, fn_call_node,
-                buf_sprintf("undefined function: '%s'", buf_ptr(name)));
-        return LLVMConstNull(LLVMInt32Type());
-    }
-    FnTableEntry *fn_table_entry = entry->value;
+    FnTableEntry *fn_table_entry = g->fn_table.get(name);
     assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
     int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
     int actual_param_count = fn_call_node->data.fn_call.params.length;
-    if (expected_param_count != actual_param_count) {
-        add_node_error(g, fn_call_node,
-                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
-                    expected_param_count, actual_param_count));
-        return LLVMConstNull(LLVMInt32Type());
-    }
+    assert(expected_param_count == actual_param_count);
 
     LLVMValueRef *param_values = allocate<LLVMValueRef>(actual_param_count);
     for (int i = 0; i < actual_param_count; i += 1) {
@@ -557,6 +600,8 @@ static llvm::DISubroutineType *create_di_function_type(CodeGen *g, AstNodeFnProt
 }
 
 void code_gen(CodeGen *g) {
+    assert(!g->errors.length);
+
     Buf *producer = buf_sprintf("zig %s", ZIG_VERSION_STRING);
     bool is_optimized = g->build_type == CodeGenBuildTypeRelease;
     const char *flags = "";
@@ -570,16 +615,19 @@ void code_gen(CodeGen *g) {
 
     g->di_file = g->dbuilder->createFile(g->compile_unit->getFilename(), g->compile_unit->getDirectory());
 
-    auto it = g->fn_defs.entry_iterator();
+
+    // Generate function prototypes
+    auto it = g->fn_table.entry_iterator();
     for (;;) {
         auto *entry = it.next();
         if (!entry)
             break;
 
-        AstNode *fn_def_node = entry->value;
-        AstNodeFnDef *fn_def = &fn_def_node->data.fn_def;
-        assert(fn_def->fn_proto->type == NodeTypeFnProto);
-        AstNodeFnProto *fn_proto = &fn_def->fn_proto->data.fn_proto;
+        FnTableEntry *fn_table_entry = entry->value;
+
+        AstNode *proto_node = fn_table_entry->proto_node;
+        assert(proto_node->type == NodeTypeFnProto);
+        AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
         LLVMTypeRef ret_type = to_llvm_type(fn_proto->return_type);
         LLVMTypeRef *param_types = allocate<LLVMTypeRef>(fn_proto->params.length);
@@ -592,13 +640,29 @@ void code_gen(CodeGen *g) {
         LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, fn_proto->params.length, 0);
         LLVMValueRef fn = LLVMAddFunction(g->mod, buf_ptr(&fn_proto->name), function_type);
 
-        bool internal_linkage = false;
-        LLVMSetLinkage(fn, internal_linkage ? LLVMPrivateLinkage : LLVMExternalLinkage);
+        LLVMSetLinkage(fn, fn_table_entry->internal_linkage ? LLVMPrivateLinkage : LLVMExternalLinkage);
 
         if (type_is_unreachable(fn_proto->return_type)) {
             LLVMAddFunctionAttr(fn, LLVMNoReturnAttribute);
         }
-        LLVMAddFunctionAttr(fn, LLVMNoUnwindAttribute);
+        if (fn_table_entry->is_extern) {
+            LLVMSetFunctionCallConv(fn, LLVMCCallConv);
+        } else {
+            LLVMAddFunctionAttr(fn, LLVMNoUnwindAttribute);
+        }
+
+        fn_table_entry->fn_value = fn;
+    }
+
+    // Generate function definitions.
+    for (int i = 0; i < g->fn_defs.length; i += 1) {
+        FnTableEntry *fn_table_entry = g->fn_defs.at(i);
+        AstNode *fn_def_node = fn_table_entry->fn_def_node;
+        LLVMValueRef fn = fn_table_entry->fn_value;
+
+        AstNode *proto_node = fn_table_entry->proto_node;
+        assert(proto_node->type == NodeTypeFnProto);
+        AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
         // Add debug info.
         llvm::DIScope *fn_scope = g->di_file;
@@ -609,17 +673,19 @@ void code_gen(CodeGen *g) {
         llvm::Function *unwrapped_function = reinterpret_cast<llvm::Function*>(llvm::unwrap(fn));
         llvm::DISubprogram *subprogram = g->dbuilder->createFunction(
             fn_scope, buf_ptr(&fn_proto->name), "", g->di_file, line_number,
-            create_di_function_type(g, fn_proto, g->di_file), internal_linkage, 
+            create_di_function_type(g, fn_proto, g->di_file), fn_table_entry->internal_linkage, 
             is_definition, scope_line, flags, is_optimized, unwrapped_function);
 
         g->block_scopes.append(subprogram);
 
-
         LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
         LLVMPositionBuilderAtEnd(g->builder, entry_block);
 
-        gen_block(g, fn_def->body);
+        gen_block(g, fn_def_node->data.fn_def.body);
+
+        g->block_scopes.pop();
     }
+    assert(!g->errors.length);
 
     g->dbuilder->finalize();