Commit 58e375d0a1

Andrew Kelley <superjoe30@gmail.com>
2015-12-01 10:08:58
support multiple files
1 parent 29f24e3
example/multiple_files/foo.zig
@@ -6,6 +6,6 @@ fn private_function() {
     puts("it works!");
 }
 
-fn print_text() {
+pub fn print_text() {
     private_function();
 }
example/multiple_files/libc.zig
@@ -1,5 +1,5 @@
 #link("c")
 extern {
-    fn puts(s: *mut u8) -> i32;
-    fn exit(code: i32) -> unreachable;
+    pub fn puts(s: *mut u8) -> i32;
+    pub fn exit(code: i32) -> unreachable;
 }
example/multiple_files/main.zig
@@ -3,7 +3,7 @@ export executable "test";
 use "libc.zig";
 use "foo.zig";
 
-fn _start() -> unreachable {
+export fn _start() -> unreachable {
     private_function();
 }
 
src/analyze.cpp
@@ -137,6 +137,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
                 assert(fn_decl->type == NodeTypeFnDecl);
                 AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
+                bool is_pub = (fn_proto->data.fn_proto.visib_mod == FnProtoVisibModPub);
                 resolve_function_proto(g, fn_proto);
                 Buf *name = &fn_proto->data.fn_proto.name;
 
@@ -145,7 +146,12 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 fn_table_entry->is_extern = true;
                 fn_table_entry->calling_convention = LLVMCCallConv;
                 fn_table_entry->import_entry = import;
-                g->fn_table.put(name, fn_table_entry);
+
+                g->fn_protos.append(fn_table_entry);
+                import->fn_table.put(name, fn_table_entry);
+                if (is_pub) {
+                    g->fn_table.put(name, fn_table_entry);
+                }
             }
             break;
         case NodeTypeFnDef:
@@ -153,27 +159,44 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 AstNode *proto_node = node->data.fn_def.fn_proto;
                 assert(proto_node->type == NodeTypeFnProto);
                 Buf *proto_name = &proto_node->data.fn_proto.name;
-                auto entry = g->fn_table.maybe_get(proto_name);
+                auto entry = import->fn_table.maybe_get(proto_name);
+                bool skip = false;
+                bool is_internal = (proto_node->data.fn_proto.visib_mod != FnProtoVisibModExport);
+                bool is_pub = (proto_node->data.fn_proto.visib_mod == FnProtoVisibModPub);
                 if (entry) {
                     add_node_error(g, node,
                             buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
                     assert(!node->codegen_node);
                     node->codegen_node = allocate<CodeGenNode>(1);
                     node->codegen_node->data.fn_def_node.skip = true;
-                } else {
+                    skip = true;
+                } else if (is_pub) {
+                    auto entry = g->fn_table.maybe_get(proto_name);
+                    if (entry) {
+                        add_node_error(g, node,
+                                buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
+                        assert(!node->codegen_node);
+                        node->codegen_node = allocate<CodeGenNode>(1);
+                        node->codegen_node->data.fn_def_node.skip = true;
+                        skip = true;
+                    }
+                }
+                if (!skip) {
                     FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
                     fn_table_entry->import_entry = import;
                     fn_table_entry->proto_node = proto_node;
                     fn_table_entry->fn_def_node = node;
-                    fn_table_entry->internal_linkage = proto_node->data.fn_proto.visib_mod != FnProtoVisibModExport;
-                    if (fn_table_entry->internal_linkage) {
-                        fn_table_entry->calling_convention = LLVMFastCallConv;
-                    } else {
-                        fn_table_entry->calling_convention = LLVMCCallConv;
-                    }
-                    g->fn_table.put(proto_name, fn_table_entry);
+                    fn_table_entry->internal_linkage = is_internal;
+                    fn_table_entry->calling_convention = is_internal ? LLVMFastCallConv : LLVMCCallConv;
+
+                    g->fn_protos.append(fn_table_entry);
                     g->fn_defs.append(fn_table_entry);
 
+                    import->fn_table.put(proto_name, fn_table_entry);
+                    if (is_pub) {
+                        g->fn_table.put(proto_name, fn_table_entry);
+                    }
+
                     resolve_function_proto(g, proto_node);
                 }
             }
@@ -297,28 +320,31 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
     }
 }
 
-static void analyze_expression(CodeGen *g, AstNode *node) {
+static void analyze_expression(CodeGen *g, ImportTableEntry *import, AstNode *node) {
     switch (node->type) {
         case NodeTypeBlock:
             for (int i = 0; i < node->data.block.statements.length; i += 1) {
                 AstNode *child = node->data.block.statements.at(i);
-                analyze_expression(g, child);
+                analyze_expression(g, import, child);
             }
             break;
         case NodeTypeReturnExpr:
             if (node->data.return_expr.expr) {
-                analyze_expression(g, node->data.return_expr.expr);
+                analyze_expression(g, import, node->data.return_expr.expr);
             }
             break;
         case NodeTypeBinOpExpr:
-            analyze_expression(g, node->data.bin_op_expr.op1);
-            analyze_expression(g, node->data.bin_op_expr.op2);
+            analyze_expression(g, import, node->data.bin_op_expr.op1);
+            analyze_expression(g, import, node->data.bin_op_expr.op2);
             break;
         case NodeTypeFnCallExpr:
             {
                 Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
 
-                auto entry = g->fn_table.maybe_get(name);
+                auto entry = import->fn_table.maybe_get(name);
+                if (!entry)
+                    entry = g->fn_table.maybe_get(name);
+
                 if (!entry) {
                     add_node_error(g, node,
                             buf_sprintf("undefined function: '%s'", buf_ptr(name)));
@@ -336,7 +362,7 @@ static void analyze_expression(CodeGen *g, AstNode *node) {
 
                 for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
                     AstNode *child = node->data.fn_call_expr.params.at(i);
-                    analyze_expression(g, child);
+                    analyze_expression(g, import, child);
                 }
                 break;
             }
@@ -366,7 +392,7 @@ static void analyze_expression(CodeGen *g, AstNode *node) {
     }
 }
 
-static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
+static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, AstNode *node) {
     switch (node->type) {
         case NodeTypeFnDef:
             {
@@ -387,7 +413,7 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
                 }
 
                 check_fn_def_control_flow(g, node);
-                analyze_expression(g, node->data.fn_def.body);
+                analyze_expression(g, import, node->data.fn_def.body);
             }
             break;
 
@@ -423,33 +449,50 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
     }
 }
 
-static void analyze_root(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+static void find_function_declarations_root(CodeGen *g, ImportTableEntry *import, AstNode *node) {
     assert(node->type == NodeTypeRoot);
 
-    // find function declarations
     for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
         AstNode *child = node->data.root.top_level_decls.at(i);
         preview_function_declarations(g, import, child);
     }
 
+}
+
+static void analyze_top_level_decls_root(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+    assert(node->type == NodeTypeRoot);
+
     for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
         AstNode *child = node->data.root.top_level_decls.at(i);
-        analyze_top_level_declaration(g, child);
+        analyze_top_level_declaration(g, import, child);
     }
-
 }
 
 void semantic_analyze(CodeGen *g) {
-    auto it = g->import_table.entry_iterator();
-    for (;;) {
-        auto *entry = it.next();
-        if (!entry)
-            break;
+    {
+        auto it = g->import_table.entry_iterator();
+        for (;;) {
+            auto *entry = it.next();
+            if (!entry)
+                break;
 
-        ImportTableEntry *import = entry->value;
-        analyze_root(g, import, import->root);
+            ImportTableEntry *import = entry->value;
+            find_function_declarations_root(g, import, import->root);
+        }
+    }
+    {
+        auto it = g->import_table.entry_iterator();
+        for (;;) {
+            auto *entry = it.next();
+            if (!entry)
+                break;
+
+            ImportTableEntry *import = entry->value;
+            analyze_top_level_decls_root(g, import, import->root);
+        }
     }
 
+
     if (!g->root_out_name) {
         add_node_error(g, g->root_import->root,
                 buf_sprintf("missing export declaration and output name not provided"));
src/codegen.cpp
@@ -125,7 +125,13 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
 
     Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
 
-    FnTableEntry *fn_table_entry = g->fn_table.get(name);
+    FnTableEntry *fn_table_entry;
+    auto entry = g->cur_fn->import_entry->fn_table.maybe_get(name);
+    if (entry)
+        fn_table_entry = entry->value;
+    else
+        fn_table_entry = g->fn_table.get(name);
+
     assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
     int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
     int actual_param_count = node->data.fn_call_expr.params.length;
@@ -478,13 +484,8 @@ static void do_code_gen(CodeGen *g) {
 
 
     // Generate function prototypes
-    auto it = g->fn_table.entry_iterator();
-    for (;;) {
-        auto *entry = it.next();
-        if (!entry)
-            break;
-
-        FnTableEntry *fn_table_entry = entry->value;
+    for (int i = 0; i < g->fn_protos.length; i += 1) {
+        FnTableEntry *fn_table_entry = g->fn_protos.at(i);
 
         AstNode *proto_node = fn_table_entry->proto_node;
         assert(proto_node->type == NodeTypeFnProto);
@@ -547,6 +548,7 @@ static void do_code_gen(CodeGen *g) {
         assert(codegen_node);
 
         FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node;
+        assert(codegen_fn_def);
         codegen_fn_def->params = allocate<LLVMValueRef>(LLVMCountParams(fn));
         LLVMGetParams(fn, codegen_fn_def->params);
 
@@ -733,9 +735,9 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *source_path, Buf *sou
         if (!entry) {
             Buf full_path = BUF_INIT;
             os_path_join(g->root_source_dir, &top_level_decl->data.use.path, &full_path);
-            Buf import_code = BUF_INIT;
-            os_fetch_file_path(&full_path, &import_code);
-            codegen_add_code(g, &top_level_decl->data.use.path, &import_code);
+            Buf *import_code = buf_alloc();
+            os_fetch_file_path(&full_path, import_code);
+            codegen_add_code(g, &top_level_decl->data.use.path, import_code);
         }
     }
 
src/semantic_info.hpp
@@ -80,7 +80,14 @@ struct CodeGen {
     Buf *root_source_dir;
     Buf *root_out_name;
     ZigList<LLVMZigDIScope *> block_scopes;
+
+    // The function definitions this module includes. There must be a corresponding
+    // fn_protos entry.
     ZigList<FnTableEntry *> fn_defs;
+    // The function prototypes this module includes. In the case of external declarations,
+    // there will not be a corresponding fn_defs entry.
+    ZigList<FnTableEntry *> fn_protos;
+
     OutType out_type;
     FnTableEntry *cur_fn;
     bool c_stdint_used;
test/run_tests.cpp
@@ -14,13 +14,13 @@
 
 struct TestSourceFile {
     const char *relative_path;
-    const char *text;
+    const char *source_code;
 };
 
 struct TestCase {
     const char *case_name;
     const char *output;
-    const char *source;
+    ZigList<TestSourceFile> source_files;
     ZigList<const char *> compile_errors;
     ZigList<const char *> compiler_args;
     ZigList<const char *> program_args;
@@ -31,11 +31,20 @@ static const char *tmp_source_path = ".tmp_source.zig";
 static const char *tmp_exe_path = "./.tmp_exe";
 static const char *zig_exe = "./zig";
 
-static void add_simple_case(const char *case_name, const char *source, const char *output) {
+static void add_source_file(TestCase *test_case, const char *path, const char *source) {
+    test_case->source_files.add_one();
+    test_case->source_files.last().relative_path = path;
+    test_case->source_files.last().source_code = source;
+}
+
+static TestCase *add_simple_case(const char *case_name, const char *source, const char *output) {
     TestCase *test_case = allocate<TestCase>(1);
     test_case->case_name = case_name;
     test_case->output = output;
-    test_case->source = source;
+
+    test_case->source_files.resize(1);
+    test_case->source_files.at(0).relative_path = tmp_source_path;
+    test_case->source_files.at(0).source_code = source;
 
     test_case->compiler_args.append("build");
     test_case->compiler_args.append(tmp_source_path);
@@ -52,15 +61,19 @@ static void add_simple_case(const char *case_name, const char *source, const cha
     test_case->compiler_args.append("on");
 
     test_cases.append(test_case);
+
+    return test_case;
 }
 
-static void add_compile_fail_case(const char *case_name, const char *source, int count, ...) {
+static TestCase *add_compile_fail_case(const char *case_name, const char *source, int count, ...) {
     va_list ap;
     va_start(ap, count);
 
     TestCase *test_case = allocate<TestCase>(1);
     test_case->case_name = case_name;
-    test_case->source = source;
+    test_case->source_files.resize(1);
+    test_case->source_files.at(0).relative_path = tmp_source_path;
+    test_case->source_files.at(0).source_code = source;
 
     for (int i = 0; i < count; i += 1) {
         const char *arg = va_arg(ap, const char *);
@@ -78,6 +91,8 @@ static void add_compile_fail_case(const char *case_name, const char *source, int
     test_cases.append(test_case);
 
     va_end(ap);
+
+    return test_case;
 }
 
 static void add_compiling_test_cases(void) {
@@ -135,6 +150,45 @@ static void add_compiling_test_cases(void) {
             exit(0);
         }
     )SOURCE", "OK\n");
+
+    {
+        TestCase *tc = add_simple_case("multiple files with private function", R"SOURCE(
+            use "libc.zig";
+            use "foo.zig";
+
+            export fn _start() -> unreachable {
+                private_function();
+            }
+
+            fn private_function() -> unreachable {
+                print_text();
+                exit(0);
+            }
+        )SOURCE", "OK\n");
+
+        add_source_file(tc, "libc.zig", R"SOURCE(
+            #link("c")
+            extern {
+                pub fn puts(s: *mut u8) -> i32;
+                pub fn exit(code: i32) -> unreachable;
+            }
+        )SOURCE");
+
+        add_source_file(tc, "foo.zig", R"SOURCE(
+            use "libc.zig";
+
+            // purposefully conflicting function with main source file
+            // but it's private so it should be OK
+            fn private_function() {
+                puts("OK");
+            }
+
+            pub fn print_text() {
+                private_function();
+            }
+        )SOURCE");
+    }
+
 }
 
 static void add_compile_failure_test_cases(void) {
@@ -207,7 +261,12 @@ static void print_compiler_invokation(TestCase *test_case, Buf *zig_stderr) {
 }
 
 static void run_test(TestCase *test_case) {
-    os_write_file(buf_create_from_str(tmp_source_path), buf_create_from_str(test_case->source));
+    for (int i = 0; i < test_case->source_files.length; i += 1) {
+        TestSourceFile *test_source = &test_case->source_files.at(i);
+        os_write_file(
+                buf_create_from_str(test_source->relative_path),
+                buf_create_from_str(test_source->source_code));
+    }
 
     Buf zig_stderr = BUF_INIT;
     Buf zig_stdout = BUF_INIT;
@@ -265,6 +324,11 @@ static void run_test(TestCase *test_case) {
         printf("=======================================\n");
         exit(1);
     }
+
+    for (int i = 0; i < test_case->source_files.length; i += 1) {
+        TestSourceFile *test_source = &test_case->source_files.at(i);
+        remove(test_source->relative_path);
+    }
 }
 
 static void run_all_tests(void) {
README.md
@@ -43,7 +43,6 @@ make
 ## Roadmap
 
  * variable declarations and assignment expressions
- * Multiple files
  * Type checking
  * inline assembly and syscalls
  * running code at compile time