Commit 370de7386c

Andrew Kelley <superjoe30@gmail.com>
2015-12-02 08:53:57
fix parameter access and thus shared library example
1 parent 08a2311
Changed files (5)
example/shared_library/mathtest.zig
@@ -2,5 +2,5 @@
 export library "mathtest";
 
 export fn add(a: i32, b: i32) -> i32 {
-    return a + b;
+    a + b
 }
src/analyze.cpp
@@ -74,7 +74,7 @@ TypeTableEntry *get_pointer_to_type(CodeGen *g, TypeTableEntry *child_type, bool
     }
 }
 
-static void resolve_type(CodeGen *g, AstNode *node) {
+static TypeTableEntry *resolve_type(CodeGen *g, AstNode *node) {
     assert(!node->codegen_node);
     node->codegen_node = allocate<CodeGenNode>(1);
     TypeNode *type_node = &node->codegen_node->data.type_node;
@@ -90,7 +90,7 @@ static void resolve_type(CodeGen *g, AstNode *node) {
                             buf_sprintf("invalid type name: '%s'", buf_ptr(name)));
                     type_node->entry = g->builtin_types.entry_invalid;
                 }
-                break;
+                return type_node->entry;
             }
         case AstNodeTypeTypePointer:
             {
@@ -101,12 +101,12 @@ static void resolve_type(CodeGen *g, AstNode *node) {
                             buf_create_from_str("pointer to unreachable not allowed"));
                 }
                 type_node->entry = get_pointer_to_type(g, child_type, node->data.type.is_const);
-                break;
+                return type_node->entry;
             }
     }
 }
 
-static void resolve_function_proto(CodeGen *g, AstNode *node) {
+static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry) {
     assert(node->type == NodeTypeFnProto);
 
     for (int i = 0; i < node->data.fn_proto.directives->length; i += 1) {
@@ -120,9 +120,11 @@ static void resolve_function_proto(CodeGen *g, AstNode *node) {
         AstNode *child = node->data.fn_proto.params.at(i);
         assert(child->type == NodeTypeParamDecl);
 
-        // parameter names are not important here.
-
-        resolve_type(g, child->data.param_decl.type);
+        Buf *param_name = &child->data.param_decl.name;
+        SymbolTableEntry *symbol_entry = allocate<SymbolTableEntry>(1);
+        symbol_entry->type_entry = resolve_type(g, child->data.param_decl.type);
+        symbol_entry->param_index = i;
+        fn_table_entry->symbol_table.put(param_name, symbol_entry);
     }
 
     resolve_type(g, node->data.fn_proto.return_type);
@@ -148,20 +150,26 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                 assert(fn_decl->type == NodeTypeFnDecl);
                 AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
                 bool is_pub = (fn_proto->data.fn_proto.visib_mod == FnProtoVisibModPub);
-                resolve_function_proto(g, fn_proto);
-                Buf *name = &fn_proto->data.fn_proto.name;
 
                 FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
                 fn_table_entry->proto_node = fn_proto;
                 fn_table_entry->is_extern = true;
                 fn_table_entry->calling_convention = LLVMCCallConv;
                 fn_table_entry->import_entry = import;
+                fn_table_entry->symbol_table.init(8);
+
+                resolve_function_proto(g, fn_proto, fn_table_entry);
 
+                Buf *name = &fn_proto->data.fn_proto.name;
                 g->fn_protos.append(fn_table_entry);
                 import->fn_table.put(name, fn_table_entry);
                 if (is_pub) {
                     g->fn_table.put(name, fn_table_entry);
                 }
+
+                assert(!fn_proto->codegen_node);
+                fn_proto->codegen_node = allocate<CodeGenNode>(1);
+                fn_proto->codegen_node->data.fn_proto_node.fn_table_entry = fn_table_entry;
             }
             break;
         case NodeTypeFnDef:
@@ -198,6 +206,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                     fn_table_entry->fn_def_node = node;
                     fn_table_entry->internal_linkage = is_internal;
                     fn_table_entry->calling_convention = is_internal ? LLVMFastCallConv : LLVMCCallConv;
+                    fn_table_entry->symbol_table.init(8);
 
                     g->fn_protos.append(fn_table_entry);
                     g->fn_defs.append(fn_table_entry);
@@ -207,7 +216,11 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
                         g->fn_table.put(proto_name, fn_table_entry);
                     }
 
-                    resolve_function_proto(g, proto_node);
+                    resolve_function_proto(g, proto_node, fn_table_entry);
+
+                    assert(!proto_node->codegen_node);
+                    proto_node->codegen_node = allocate<CodeGenNode>(1);
+                    proto_node->codegen_node->data.fn_proto_node.fn_table_entry = fn_table_entry;
                 }
             }
             break;
@@ -289,6 +302,16 @@ static TypeTableEntry * get_return_type(BlockContext *context) {
     return return_type_node->codegen_node->data.type_node.entry;
 }
 
+static FnTableEntry *get_context_fn_entry(BlockContext *context) {
+    AstNode *fn_def_node = context->root->node;
+    assert(fn_def_node->type == NodeTypeFnDef);
+    AstNode *fn_proto_node = fn_def_node->data.fn_def.fn_proto;
+    assert(fn_proto_node->type == NodeTypeFnProto);
+    assert(fn_proto_node->codegen_node);
+    assert(fn_proto_node->codegen_node->data.fn_proto_node.fn_table_entry);
+    return fn_proto_node->codegen_node->data.fn_proto_node.fn_table_entry;
+}
+
 static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry *expected_type, TypeTableEntry *actual_type) {
     if (expected_type == nullptr)
         return; // anything will do
@@ -482,9 +505,20 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
             break;
 
         case NodeTypeSymbol:
-            // look up symbol in symbol table
-            zig_panic("TODO analyze_expression symbol");
-
+            {
+                Buf *symbol_name = &node->data.symbol;
+                FnTableEntry *fn_table_entry = get_context_fn_entry(context);
+                auto table_entry = fn_table_entry->symbol_table.maybe_get(symbol_name);
+                if (table_entry) {
+                    SymbolTableEntry *symbol_entry = table_entry->value;
+                    return_type = symbol_entry->type_entry;
+                } else {
+                    add_node_error(g, node,
+                            buf_sprintf("use of undeclared identifier '%s'", buf_ptr(symbol_name)));
+                    return_type = g->builtin_types.entry_invalid;
+                }
+                break;
+            }
         case NodeTypeCastExpr:
             zig_panic("TODO analyze_expression cast expr");
             break;
src/codegen.cpp
@@ -105,19 +105,13 @@ static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) {
 
 static LLVMValueRef get_variable_value(CodeGen *g, Buf *name) {
     assert(g->cur_fn->proto_node->type == NodeTypeFnProto);
-    int param_count = g->cur_fn->proto_node->data.fn_proto.params.length;
-    for (int i = 0; i < param_count; i += 1) {
-        AstNode *param_decl_node = g->cur_fn->proto_node->data.fn_proto.params.at(i);
-        assert(param_decl_node->type == NodeTypeParamDecl);
-        Buf *param_name = &param_decl_node->data.param_decl.name;
-        if (buf_eql_buf(name, param_name)) {
-            CodeGenNode *codegen_node = g->cur_fn->fn_def_node->codegen_node;
-            assert(codegen_node);
-            FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node;
-            return codegen_fn_def->params[i];
-        }
-    }
-    zig_unreachable();
+
+    SymbolTableEntry *symbol_entry = g->cur_fn->symbol_table.get(name);
+
+    CodeGenNode *codegen_node = g->cur_fn->fn_def_node->codegen_node;
+    assert(codegen_node);
+    FnDefNode *codegen_fn_def = &codegen_node->data.fn_def_node;
+    return codegen_fn_def->params[symbol_entry->param_index];
 }
 
 static TypeTableEntry *get_expr_type(AstNode *node) {
src/semantic_info.hpp
@@ -38,6 +38,11 @@ struct ImportTableEntry {
     HashMap<Buf *, FnTableEntry *, buf_hash, buf_eql_buf> fn_table;
 };
 
+struct SymbolTableEntry {
+    TypeTableEntry *type_entry;
+    int param_index; // only valid in the case of parameters
+};
+
 struct FnTableEntry {
     LLVMValueRef fn_value;
     AstNode *proto_node;
@@ -46,6 +51,9 @@ struct FnTableEntry {
     bool internal_linkage;
     unsigned calling_convention;
     ImportTableEntry *import_entry;
+
+    // reminder: hash tables must be initialized before use
+    HashMap<Buf *, SymbolTableEntry *, buf_hash, buf_eql_buf> symbol_table;
 };
 
 struct CodeGen {
@@ -106,6 +114,10 @@ struct TypeNode {
     TypeTableEntry *entry;
 };
 
+struct FnProtoNode {
+    FnTableEntry *fn_table_entry;
+};
+
 struct FnDefNode {
     TypeTableEntry *implicit_return_type;
     bool skip;
@@ -121,6 +133,7 @@ struct CodeGenNode {
         TypeNode type_node; // for NodeTypeType
         FnDefNode fn_def_node; // for NodeTypeFnDef
         ExprNode expr_node; // for all the expression nodes
+        FnProtoNode fn_proto_node; // for NodeTypeFnProto
     } data;
 };
 
test/run_tests.cpp
@@ -213,6 +213,25 @@ static void add_compiling_test_cases(void) {
             exit(0);
         }
     )SOURCE", "1 is true\n!0 is true\n");
+
+    add_simple_case("params", R"SOURCE(
+        #link("c")
+        extern {
+            fn puts(s: *const u8) -> i32;
+            fn exit(code: i32) -> unreachable;
+        }
+
+        fn add(a: i32, b: i32) -> i32 {
+            a + b
+        }
+
+        export fn _start() -> unreachable {
+            if add(22, 11) == 33 {
+                puts("pass");
+            }
+            exit(0);
+        }
+    )SOURCE", "pass\n");
 }
 
 static void add_compile_failure_test_cases(void) {