Commit 44d5d008d0

Andrew Kelley <superjoe30@gmail.com>
2016-01-04 11:31:57
partial import segregation
See #3
1 parent 333a322
src/analyze.cpp
@@ -13,15 +13,18 @@
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node);
 
-static void alloc_codegen_node(AstNode *node) {
-    assert(!node->codegen_node);
-    node->codegen_node = allocate<CodeGenNode>(1);
-}
-
 static AstNode *first_executing_node(AstNode *node) {
     switch (node->type) {
         case NodeTypeFnCallExpr:
             return first_executing_node(node->data.fn_call_expr.fn_ref_expr);
+        case NodeTypeBinOpExpr:
+            return first_executing_node(node->data.bin_op_expr.op1);
+        case NodeTypeArrayAccessExpr:
+            return first_executing_node(node->data.array_access_expr.array_ref_expr);
+        case NodeTypeFieldAccessExpr:
+            return first_executing_node(node->data.field_access_expr.struct_expr);
+        case NodeTypeCastExpr:
+            return first_executing_node(node->data.cast_expr.expr);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -34,15 +37,12 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeDirective:
         case NodeTypeReturnExpr:
         case NodeTypeVariableDeclaration:
-        case NodeTypeBinOpExpr:
-        case NodeTypeCastExpr:
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeCharLiteral:
         case NodeTypeUnreachable:
         case NodeTypeSymbol:
         case NodeTypePrefixOpExpr:
-        case NodeTypeArrayAccessExpr:
         case NodeTypeUse:
         case NodeTypeVoid:
         case NodeTypeBoolLiteral:
@@ -53,7 +53,6 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeBreak:
         case NodeTypeContinue:
         case NodeTypeAsmExpr:
-        case NodeTypeFieldAccessExpr:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueExpr:
@@ -476,7 +475,6 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
                 assert(fn_decl->type == NodeTypeFnDecl);
                 AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
-                bool is_pub = (fn_proto->data.fn_proto.visib_mod == FnProtoVisibModPub);
 
                 FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
                 fn_table_entry->proto_node = fn_proto;
@@ -490,9 +488,6 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 Buf *name = &fn_proto->data.fn_proto.name;
                 g->fn_protos.append(fn_table_entry);
                 import->fn_table.put(name, fn_table_entry);
-                if (is_pub) {
-                    g->fn_table.put(name, fn_table_entry);
-                }
 
                 alloc_codegen_node(fn_proto);
                 fn_proto->codegen_node->data.fn_proto_node.fn_table_entry = fn_table_entry;
@@ -514,7 +509,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                     node->codegen_node->data.fn_def_node.skip = true;
                     skip = true;
                 } else if (is_pub) {
-                    auto entry = g->fn_table.maybe_get(proto_name);
+                    auto entry = import->fn_table.maybe_get(proto_name);
                     if (entry) {
                         add_node_error(g, node,
                                 buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
@@ -540,8 +535,9 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                     g->fn_defs.append(fn_table_entry);
 
                     import->fn_table.put(proto_name, fn_table_entry);
-                    if (is_pub) {
-                        g->fn_table.put(proto_name, fn_table_entry);
+
+                    if (g->bootstrap_import && import == g->root_import && buf_eql_str(proto_name, "main")) {
+                        g->bootstrap_import->fn_table.put(proto_name, fn_table_entry);
                     }
 
                     resolve_function_proto(g, proto_node, fn_table_entry, import);
@@ -1748,8 +1744,6 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                 Buf *name = &fn_ref_expr->data.symbol;
 
                 auto entry = import->fn_table.maybe_get(name);
-                if (!entry)
-                    entry = g->fn_table.maybe_get(name);
 
                 if (!entry) {
                     add_node_error(g, fn_ref_expr,
@@ -2011,13 +2005,41 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
             // already looked at these in the preview pass
             break;
         case NodeTypeUse:
-            for (int i = 0; i < node->data.use.directives->length; i += 1) {
-                AstNode *directive_node = node->data.use.directives->at(i);
-                Buf *name = &directive_node->data.directive.name;
-                add_node_error(g, directive_node,
-                        buf_sprintf("invalid directive: '%s'", buf_ptr(name)));
+            {
+                for (int i = 0; i < node->data.use.directives->length; i += 1) {
+                    AstNode *directive_node = node->data.use.directives->at(i);
+                    Buf *name = &directive_node->data.directive.name;
+                    add_node_error(g, directive_node,
+                            buf_sprintf("invalid directive: '%s'", buf_ptr(name)));
+                }
+
+                ImportTableEntry *target_import = node->codegen_node->data.import_node.import;
+                assert(target_import);
+
+                // import all the public functions
+                {
+                    auto it = target_import->fn_table.entry_iterator();
+                    for (;;) {
+                        auto *entry = it.next();
+                        if (!entry)
+                            break;
+
+                        FnTableEntry *fn_entry = entry->value;
+                        bool is_pub = (fn_entry->proto_node->data.fn_proto.visib_mod != FnProtoVisibModPrivate);
+                        if (is_pub) {
+                            auto existing_entry = import->fn_table.maybe_get(entry->key);
+                            if (existing_entry) {
+                                add_node_error(g, node,
+                                    buf_sprintf("import of function '%s' overrides existing definition",
+                                        buf_ptr(&fn_entry->proto_node->data.fn_proto.name)));
+                            } else {
+                                import->fn_table.put(entry->key, entry->value);
+                            }
+                        }
+                    }
+                }
+                break;
             }
-            break;
         case NodeTypeStructDecl:
             // nothing to do
             break;
@@ -2118,6 +2140,7 @@ void semantic_analyze(CodeGen *g) {
             find_function_declarations_root(g, import, import->root);
         }
     }
+
     {
         auto it = g->import_table.entry_iterator();
         for (;;) {
@@ -2139,3 +2162,9 @@ void semantic_analyze(CodeGen *g) {
                 buf_sprintf("missing export declaration and export type not provided"));
     }
 }
+
+void alloc_codegen_node(AstNode *node) {
+    assert(!node->codegen_node);
+    node->codegen_node = allocate<CodeGenNode>(1);
+}
+
src/analyze.hpp
@@ -150,7 +150,6 @@ struct CodeGen {
     ZigList<Buf *> lib_search_paths;
 
     // reminder: hash tables must be initialized before use
-    HashMap<Buf *, FnTableEntry *, buf_hash, buf_eql_buf> fn_table;
     HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
     HashMap<Buf *, TypeTableEntry *, buf_hash, buf_eql_buf> type_table;
     HashMap<Buf *, bool, buf_hash, buf_eql_buf> link_table;
@@ -215,7 +214,9 @@ struct CodeGen {
     bool verbose;
     ErrColor err_color;
     ImportTableEntry *root_import;
+    ImportTableEntry *bootstrap_import;
     LLVMValueRef memcpy_fn_val;
+    bool error_during_imports;
 };
 
 struct VariableTableEntry {
@@ -328,6 +329,10 @@ struct ParamDeclNode {
     VariableTableEntry *variable;
 };
 
+struct ImportNode {
+    ImportTableEntry *import;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -345,6 +350,7 @@ struct CodeGenNode {
         StructValExprNode struct_val_expr_node; // for NodeTypeStructValueExpr
         IfVarNode if_var_node; // for NodeTypeStructValueExpr
         ParamDeclNode param_decl_node; // for NodeTypeParamDecl
+        ImportNode import_node; // for NodeTypeUse
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
@@ -358,6 +364,7 @@ static inline Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
 
 void semantic_analyze(CodeGen *g);
 void add_node_error(CodeGen *g, AstNode *node, Buf *msg);
+void alloc_codegen_node(AstNode *node);
 TypeTableEntry *new_type_table_entry(TypeTableEntryId id);
 TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool is_const);
 VariableTableEntry *find_variable(BlockContext *context, Buf *name);
src/codegen.cpp
@@ -19,7 +19,6 @@
 
 CodeGen *codegen_create(Buf *root_source_dir) {
     CodeGen *g = allocate<CodeGen>(1);
-    g->fn_table.init(32);
     g->str_table.init(32);
     g->type_table.init(32);
     g->link_table.init(32);
@@ -146,12 +145,7 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
 
     Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
 
-    FnTableEntry *fn_table_entry;
-    auto entry = g->cur_fn->import_entry->fn_table.maybe_get(name);
-    if (entry)
-        fn_table_entry = entry->value;
-    else
-        fn_table_entry = g->fn_table.get(name);
+    FnTableEntry *fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
 
     assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
     AstNodeFnProto *fn_proto_data = &fn_table_entry->proto_node->data.fn_proto;
@@ -2062,6 +2056,8 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
             Buf *import_code = buf_alloc();
             bool found_it = false;
 
+            alloc_codegen_node(top_level_decl);
+
             for (int path_i = 0; path_i < g->lib_search_paths.length; path_i += 1) {
                 Buf *search_path = g->lib_search_paths.at(path_i);
                 os_path_join(search_path, import_target_path, &full_path);
@@ -2071,6 +2067,7 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
                     if (err == ErrorFileNotFound) {
                         continue;
                     } else {
+                        g->error_during_imports = true;
                         add_node_error(g, top_level_decl,
                                 buf_sprintf("unable to open '%s': %s", buf_ptr(&full_path), err_str(err)));
                         goto done_looking_at_imports;
@@ -2080,22 +2077,26 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
                 auto entry = g->import_table.maybe_get(abs_full_path);
                 if (entry) {
                     found_it = true;
+                    top_level_decl->codegen_node->data.import_node.import = entry->value;
                 } else {
                     if ((err = os_fetch_file_path(abs_full_path, import_code))) {
                         if (err == ErrorFileNotFound) {
                             continue;
                         } else {
+                            g->error_during_imports = true;
                             add_node_error(g, top_level_decl,
                                     buf_sprintf("unable to open '%s': %s", buf_ptr(&full_path), err_str(err)));
                             goto done_looking_at_imports;
                         }
                     }
-                    codegen_add_code(g, abs_full_path, search_path, &top_level_decl->data.use.path, import_code);
+                    top_level_decl->codegen_node->data.import_node.import = codegen_add_code(g,
+                            abs_full_path, search_path, &top_level_decl->data.use.path, import_code);
                     found_it = true;
                 }
                 break;
             }
             if (!found_it) {
+                g->error_during_imports = true;
                 add_node_error(g, top_level_decl,
                         buf_sprintf("unable to find '%s'", buf_ptr(import_target_path)));
             }
@@ -2147,14 +2148,16 @@ void codegen_add_root_code(CodeGen *g, Buf *src_dir, Buf *src_basename, Buf *sou
             zig_panic("unable to open '%s': %s", buf_ptr(&path_to_bootstrap_src), err_str(err));
         }
 
-        codegen_add_code(g, abs_full_path, bootstrap_dir, bootstrap_basename, import_code);
+        g->bootstrap_import = codegen_add_code(g, abs_full_path, bootstrap_dir, bootstrap_basename, import_code);
     }
 
     if (g->verbose) {
         fprintf(stderr, "\nSemantic Analysis:\n");
         fprintf(stderr, "--------------------\n");
     }
-    semantic_analyze(g);
+    if (!g->error_during_imports) {
+        semantic_analyze(g);
+    }
 
     if (g->errors.length == 0) {
         if (g->verbose) {
std/bootstrap.zig
@@ -1,5 +1,8 @@
 use "std.zig";
 
+// The compiler treats this file special by implicitly importing the function `main`
+// from the root source file.
+
 #attribute("naked")
 export fn _start() -> unreachable {
     const argc = asm("mov (%%rsp), %[argc]" : [argc] "=r" (-> isize));
test/run_tests.cpp
@@ -173,6 +173,44 @@ pub fn print_text() {
         )SOURCE");
     }
 
+    {
+        TestCase *tc = add_simple_case("import segregation", R"SOURCE(
+use "foo.zig";
+use "bar.zig";
+
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    foo_function();
+    bar_function();
+    return 0;
+}
+        )SOURCE", "OK\nOK\n");
+
+        add_source_file(tc, "foo.zig", R"SOURCE(
+use "std.zig";
+pub fn foo_function() {
+    print_str("OK\n");
+}
+        )SOURCE");
+
+        add_source_file(tc, "bar.zig", R"SOURCE(
+use "other.zig";
+use "std.zig";
+
+pub fn bar_function() {
+    if (foo_function()) {
+        print_str("OK\n");
+    }
+}
+        )SOURCE");
+
+        add_source_file(tc, "other.zig", R"SOURCE(
+pub fn foo_function() -> bool {
+    // this one conflicts with the one from foo
+    return true;
+}
+        )SOURCE");
+    }
+
     add_simple_case("if statements", R"SOURCE(
         use "std.zig";