Commit 4bbc074dd7

Andrew Kelley <superjoe30@gmail.com>
2015-11-24 10:43:45
hello world IR code looks good
1 parent 7d22a89
src/codegen.cpp
@@ -5,29 +5,33 @@
 
 #include <llvm-c/Core.h>
 
+struct FnTableEntry { // value type stored in CodeGen::fn_table, looked up by function name
+    LLVMValueRef fn_value; // LLVM function value that gen_fn_call hands to LLVMBuildCall
+    AstNode *proto_node; // NodeTypeFnProto node; its params list supplies the expected argument count
+};
+
 struct CodeGen {
+    LLVMModuleRef mod;
     AstNode *root;
-    HashMap<Buf *, AstNode *, buf_hash, buf_eql_buf> fn_decls;
+    HashMap<Buf *, AstNode *, buf_hash, buf_eql_buf> fn_defs;
     ZigList<ErrorMsg> errors;
     LLVMBuilderRef builder;
-    HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> external_fns;
-};
-
-struct ExpressionNode {
-    AstNode *type_node;
+    HashMap<Buf *, FnTableEntry *, buf_hash, buf_eql_buf> fn_table;
+    HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
 };
 
 struct CodeGenNode {
     union {
         LLVMTypeRef type_ref; // for NodeTypeType
-        ExpressionNode expr; // for NodeTypeExpression
     } data;
 };
 
 CodeGen *create_codegen(AstNode *root) {
     CodeGen *g = allocate<CodeGen>(1);
     g->root = root;
-    g->fn_decls.init(32);
+    g->fn_defs.init(32);
+    g->fn_table.init(32);
+    g->str_table.init(32);
     return g;
 }
 
@@ -41,31 +45,81 @@ static void add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
     last_msg->msg = msg;
 }
 
+static LLVMTypeRef to_llvm_type(AstNode *type_node) { // reads back the LLVM type stored on a type node
+    assert(type_node->type == NodeTypeType); // data.type_ref is only meaningful for NodeTypeType
+    assert(type_node->codegen_node); // must already be attached; presumably by analyze_node — TODO confirm
+    // NOTE(review): callers pass fn_proto.return_type, which the parser sets to nullptr when no '->' is written; that would crash above.
+    return type_node->codegen_node->data.type_ref;
+}
+
 static void analyze_node(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeRoot:
-            for (int i = 0; i < node->data.root.fn_decls.length; i += 1) {
-                AstNode *child = node->data.root.fn_decls.at(i);
+            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
+                AstNode *child = node->data.root.top_level_decls.at(i);
                 analyze_node(g, child);
             }
             break;
-        case NodeTypeFnDecl:
+        case NodeTypeExternBlock:
+            for (int fn_decl_i = 0; fn_decl_i < node->data.extern_block.fn_decls.length; fn_decl_i += 1) {
+                AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
+                analyze_node(g, fn_decl);
+
+                AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
+                Buf *name = &fn_proto->data.fn_proto.name;
+                ZigList<AstNode *> *params = &fn_proto->data.fn_proto.params;
+
+                LLVMTypeRef *fn_param_values = allocate<LLVMTypeRef>(params->length);
+                for (int param_i = 0; param_i < params->length; param_i += 1) {
+                    AstNode *param_node = params->at(param_i);
+                    assert(param_node->type == NodeTypeParamDecl);
+                    AstNode *param_type = param_node->data.param_decl.type;
+                    fn_param_values[param_i] = to_llvm_type(param_type);
+                }
+                LLVMTypeRef return_type = to_llvm_type(fn_proto->data.fn_proto.return_type);
+
+                LLVMTypeRef fn_type = LLVMFunctionType(return_type, fn_param_values, params->length, 0);
+                LLVMValueRef fn_val = LLVMAddFunction(g->mod, buf_ptr(name), fn_type);
+                LLVMSetLinkage(fn_val, LLVMExternalLinkage);
+                LLVMSetFunctionCallConv(fn_val, LLVMCCallConv);
+
+                FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
+                fn_table_entry->fn_value = fn_val;
+                fn_table_entry->proto_node = fn_proto;
+                g->fn_table.put(name, fn_table_entry);
+            }
+            break;
+        case NodeTypeFnDef:
             {
-                auto entry = g->fn_decls.maybe_get(&node->data.fn_decl.name);
+                AstNode *proto_node = node->data.fn_def.fn_proto;
+                assert(proto_node->type == NodeTypeFnProto); // compare, do not assign: '=' would overwrite the node type and always pass
+                Buf *proto_name = &proto_node->data.fn_proto.name;
+                auto entry = g->fn_defs.maybe_get(proto_name);
                 if (entry) {
                     add_node_error(g, node,
-                            buf_sprintf("redefinition of '%s'", buf_ptr(&node->data.fn_decl.name)));
+                            buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
                 } else {
-                    g->fn_decls.put(&node->data.fn_decl.name, node);
-                    for (int i = 0; i < node->data.fn_decl.params.length; i += 1) {
-                        AstNode *child = node->data.fn_decl.params.at(i);
-                        analyze_node(g, child);
-                    }
-                    analyze_node(g, node->data.fn_decl.return_type);
-                    analyze_node(g, node->data.fn_decl.body);
+                    g->fn_defs.put(proto_name, node);
+                    analyze_node(g, proto_node);
                 }
                 break;
             }
+        case NodeTypeFnDecl:
+            {
+                AstNode *proto_node = node->data.fn_decl.fn_proto;
+                assert(proto_node->type == NodeTypeFnProto);
+                analyze_node(g, proto_node);
+                break;
+            }
+        case NodeTypeFnProto:
+            {
+                for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
+                    AstNode *child = node->data.fn_proto.params.at(i);
+                    analyze_node(g, child);
+                }
+                analyze_node(g, node->data.fn_proto.return_type);
+                break;
+            }
         case NodeTypeParamDecl:
             analyze_node(g, node->data.param_decl.type);
             break;
@@ -131,47 +185,81 @@ static void analyze_node(CodeGen *g, AstNode *node) {
 }
 
 
-/* TODO external fn
-    LLVMTypeRef puts_param_types[] = {LLVMPointerType(LLVMInt8Type(), 0)};
-    LLVMTypeRef puts_type = LLVMFunctionType(LLVMInt32Type(), puts_param_types, 1, 0);
-    LLVMValueRef puts_fn = LLVMAddFunction(mod, "puts", puts_type);
-    LLVMSetLinkage(puts_fn, LLVMExternalLinkage);
-    */
-
 void semantic_analyze(CodeGen *g) {
+    g->mod = LLVMModuleCreateWithName("ZigModule");
+
     // Pass 1.
     analyze_node(g, g->root);
 }
 
-static LLVMTypeRef to_llvm_type(AstNode *type_node) {
-    assert(type_node->type == NodeTypeType);
-    assert(type_node->codegen_node);
-
-    return type_node->codegen_node->data.type_ref;
-}
+static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
 
 static LLVMValueRef gen_fn_call(CodeGen *g, AstNode *fn_call_node) {
     assert(fn_call_node->type == NodeTypeFnCall);
 
-    zig_panic("TODO support external fn declarations");
-    //LLVMTypeRef fn_type =  LLVMFunctionType(LLVMVoidType(), );
+    Buf *name = &fn_call_node->data.fn_call.name;
 
-    // resolve function name
-    //LLVMValueRef result = LLVMBuildCall(g->builder, 
+    auto entry = g->fn_table.maybe_get(name);
+    if (!entry) {
+        add_node_error(g, fn_call_node,
+                buf_sprintf("undefined function: '%s'", buf_ptr(name)));
+        return LLVMConstNull(LLVMInt32Type());
+    }
+    FnTableEntry *fn_table_entry = entry->value;
+    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
+    int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
+    int actual_param_count = fn_call_node->data.fn_call.params.length;
+    if (expected_param_count != actual_param_count) {
+        add_node_error(g, fn_call_node,
+                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+                    expected_param_count, actual_param_count));
+        return LLVMConstNull(LLVMInt32Type());
+    }
 
+    LLVMValueRef *param_values = allocate<LLVMValueRef>(actual_param_count);
+    for (int i = 0; i < actual_param_count; i += 1) {
+        AstNode *expr_node = fn_call_node->data.fn_call.params.at(i);
+        param_values[i] = gen_expr(g, expr_node);
+    }
 
-    //return value;
+    LLVMValueRef result = LLVMBuildCall(g->builder, fn_table_entry->fn_value,
+            param_values, actual_param_count, "");
+
+    return result;
+}
+
+static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) { // module global holding str's bytes, cached in g->str_table
+    auto entry = g->str_table.maybe_get(str); // reuse the global created by an earlier call, if there is one
+    if (entry) {
+        return entry->value;
+    }
+    LLVMValueRef text = LLVMConstString(buf_ptr(str), buf_len(str), false); // false = do not suppress the trailing NUL
+    LLVMValueRef global_value = LLVMAddGlobal(g->mod, LLVMTypeOf(text), ""); // unnamed global with the constant's type
+    LLVMSetLinkage(global_value, LLVMInternalLinkage);
+    LLVMSetInitializer(global_value, text);
+    LLVMSetGlobalConstant(global_value, true);
+    g->str_table.put(str, global_value); // later lookups for this string take the early return above
+    // NOTE(review): the returned value is the global itself (pointer to an i8 array), not an i8 pointer — confirm callers accept that.
+    return global_value;
 }
 
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node) {
     assert(expr_node->type == NodeTypeExpression);
     switch (expr_node->data.expression.type) {
         case AstNodeExpressionTypeNumber:
-            zig_panic("TODO number expr");
-            break;
+            {
+                Buf *number_str = &expr_node->data.expression.data.number;
+                LLVMTypeRef number_type = LLVMInt32Type();
+                LLVMValueRef number_val = LLVMConstIntOfStringAndSize(number_type,
+                        buf_ptr(number_str), buf_len(number_str), 10);
+                return number_val;
+            }
         case AstNodeExpressionTypeString:
-            zig_panic("TODO string expr");
-            break;
+            {
+                Buf *str = &expr_node->data.expression.data.string;
+                fprintf(stderr, "str = '%s'\n", buf_ptr(str));
+                return find_or_create_string(g, str);
+            }
         case AstNodeExpressionTypeFnCall:
             return gen_fn_call(g, expr_node->data.expression.data.fn_call);
     }
@@ -203,32 +291,37 @@ static void gen_block(CodeGen *g, AstNode *block_node) {
 }
 
 void code_gen(CodeGen *g) {
-    LLVMModuleRef mod = LLVMModuleCreateWithName("ZigModule");
     g->builder = LLVMCreateBuilder();
 
+    auto it = g->fn_defs.entry_iterator();
+    for (;;) {
+        auto *entry = it.next();
+        if (!entry)
+            break;
 
-    for (int fn_decl_i = 0; fn_decl_i < g->root->data.root.fn_decls.length; fn_decl_i += 1) {
-        AstNode *fn_decl_node = g->root->data.root.fn_decls.at(fn_decl_i);
-        AstNodeFnDecl *fn_decl = &fn_decl_node->data.fn_decl;
+        AstNode *fn_def_node = entry->value;
+        AstNodeFnDef *fn_def = &fn_def_node->data.fn_def;
+        assert(fn_def->fn_proto->type == NodeTypeFnProto);
+        AstNodeFnProto *fn_proto = &fn_def->fn_proto->data.fn_proto;
 
-        LLVMTypeRef ret_type = to_llvm_type(fn_decl->return_type);
-        LLVMTypeRef *param_types = allocate<LLVMTypeRef>(fn_decl->params.length);
-        for (int param_decl_i = 0; param_decl_i < fn_decl->params.length; param_decl_i += 1) {
-            AstNode *param_node = fn_decl->params.at(param_decl_i);
+        LLVMTypeRef ret_type = to_llvm_type(fn_proto->return_type);
+        LLVMTypeRef *param_types = allocate<LLVMTypeRef>(fn_proto->params.length);
+        for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) {
+            AstNode *param_node = fn_proto->params.at(param_decl_i);
             assert(param_node->type == NodeTypeParamDecl);
             AstNode *type_node = param_node->data.param_decl.type;
             param_types[param_decl_i] = to_llvm_type(type_node);
         }
-        LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, fn_decl->params.length, 0);
-        LLVMValueRef fn = LLVMAddFunction(mod, buf_ptr(&fn_decl->name), function_type);
+        LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, fn_proto->params.length, 0);
+        LLVMValueRef fn = LLVMAddFunction(g->mod, buf_ptr(&fn_proto->name), function_type);
 
-        LLVMBasicBlockRef entry = LLVMAppendBasicBlock(fn, "entry");
-        LLVMPositionBuilderAtEnd(g->builder, entry);
+        LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
+        LLVMPositionBuilderAtEnd(g->builder, entry_block);
 
-        gen_block(g, fn_decl->body);
+        gen_block(g, fn_def->body);
     }
 
-    LLVMDumpModule(mod);
+    LLVMDumpModule(g->mod);
 }
 
 ZigList<ErrorMsg> *codegen_error_messages(CodeGen *g) {
src/parser.cpp
@@ -20,8 +20,12 @@ const char *node_type_str(NodeType node_type) {
     switch (node_type) {
         case NodeTypeRoot:
             return "Root";
+        case NodeTypeFnDef:
+            return "FnDef";
         case NodeTypeFnDecl:
             return "FnDecl";
+        case NodeTypeFnProto:
+            return "FnProto";
         case NodeTypeParamDecl:
             return "ParamDecl";
         case NodeTypeType:
@@ -34,6 +38,8 @@ const char *node_type_str(NodeType node_type) {
             return "Expression";
         case NodeTypeFnCall:
             return "FnCall";
+        case NodeTypeExternBlock:
+            return "ExternBlock";
     }
     zig_unreachable();
 }
@@ -46,24 +52,30 @@ void ast_print(AstNode *node, int indent) {
     switch (node->type) {
         case NodeTypeRoot:
             fprintf(stderr, "%s\n", node_type_str(node->type));
-            for (int i = 0; i < node->data.root.fn_decls.length; i += 1) {
-                AstNode *child = node->data.root.fn_decls.at(i);
+            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
+                AstNode *child = node->data.root.top_level_decls.at(i);
                 ast_print(child, indent + 2);
             }
             break;
-        case NodeTypeFnDecl:
+        case NodeTypeFnDef:
+            {
+                fprintf(stderr, "%s\n", node_type_str(node->type));
+                AstNode *child = node->data.fn_def.fn_proto;
+                ast_print(child, indent + 2);
+                ast_print(node->data.fn_def.body, indent + 2);
+                break;
+            }
+        case NodeTypeFnProto:
             {
-                Buf *name_buf = &node->data.fn_decl.name;
+                Buf *name_buf = &node->data.fn_proto.name;
                 fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(name_buf));
 
-                for (int i = 0; i < node->data.fn_decl.params.length; i += 1) {
-                    AstNode *child = node->data.fn_decl.params.at(i);
+                for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
+                    AstNode *child = node->data.fn_proto.params.at(i);
                     ast_print(child, indent + 2);
                 }
 
-                ast_print(node->data.fn_decl.return_type, indent + 2);
-
-                ast_print(node->data.fn_decl.body, indent + 2);
+                ast_print(node->data.fn_proto.return_type, indent + 2);
 
                 break;
             }
@@ -115,6 +127,19 @@ void ast_print(AstNode *node, int indent) {
                     break;
             }
             break;
+        case NodeTypeExternBlock:
+            {
+                fprintf(stderr, "%s\n", node_type_str(node->type));
+                for (int i = 0; i < node->data.extern_block.fn_decls.length; i += 1) {
+                    AstNode *child = node->data.extern_block.fn_decls.at(i);
+                    ast_print(child, indent + 2);
+                }
+                break;
+            }
+        case NodeTypeFnDecl:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.fn_decl.fn_proto, indent + 2);
+            break;
         default:
             fprintf(stderr, "%s\n", node_type_str(node->type));
             break;
@@ -135,10 +160,52 @@ static AstNode *ast_create_node(NodeType type, Token *first_token) {
     return node;
 }
 
+static AstNode *ast_create_node_with_node(NodeType type, AstNode *other_node) { // new node of `type` located at other_node
+    AstNode *node = allocate<AstNode>(1);
+    node->type = type;
+    node->line = other_node->line; // source position is copied from an existing node instead of a token
+    node->column = other_node->column;
+    return node;
+}
+
 static void ast_buf_from_token(ParseContext *pc, Token *token, Buf *buf) {
     buf_init_from_mem(buf, buf_ptr(pc->buf) + token->start_pos, token->end_pos - token->start_pos);
 }
 
+static void parse_string_literal(ParseContext *pc, Token *token, Buf *buf) {
+    // Appends the text of string-literal `token` to `buf`, decoding the escapes \\ \r \n \t and \" as it goes.
+    // The token span includes the opening and closing double quotes; both are skipped.
+    bool escape = false;
+    for (int i = token->start_pos + 1; i < token->end_pos - 1; i += 1) {
+        uint8_t c = *((uint8_t*)buf_ptr(pc->buf) + i);
+        if (escape) {
+            switch (c) { // any character not listed below is dropped together with its backslash
+                case '\\':
+                    buf_append_char(buf, '\\');
+                    break;
+                case 'r':
+                    buf_append_char(buf, '\r');
+                    break;
+                case 'n':
+                    buf_append_char(buf, '\n');
+                    break;
+                case 't':
+                    buf_append_char(buf, '\t');
+                    break;
+                case '"':
+                    buf_append_char(buf, '"');
+                    break;
+            }
+            escape = false;
+        } else if (c == '\\') {
+            escape = true; // decode the next character instead of copying this backslash
+        } else {
+            buf_append_char(buf, c);
+        }
+    }
+    assert(!escape); // the literal must not end in the middle of an escape sequence
+}
+
 static void ast_invalid_token_error(ParseContext *pc, Token *token) {
     Buf token_value = {0};
     ast_buf_from_token(pc, token, &token_value);
@@ -304,7 +371,7 @@ static AstNode *ast_parse_expression(ParseContext *pc, int token_index, int *new
         token_index += 1;
     } else if (token->id == TokenIdStringLiteral) {
         node->data.expression.type = AstNodeExpressionTypeString;
-        ast_buf_from_token(pc, token, &node->data.expression.data.string);
+        parse_string_literal(pc, token, &node->data.expression.data.string);
         token_index += 1;
     } else {
         ast_invalid_token_error(pc, token);
@@ -381,50 +448,111 @@ static AstNode *ast_parse_block(ParseContext *pc, int token_index, int *new_toke
 }
 
 /*
-FnDecl : token(Fn) token(Symbol) ParamDeclList option(token(Arrow) Type) Block;
+FnProto : token(Fn) token(Symbol) ParamDeclList option(token(Arrow) Type)
 */
-static AstNode *ast_parse_fn_decl(ParseContext *pc, int token_index, int *new_token_index) {
+static AstNode *ast_parse_fn_proto(ParseContext *pc, int token_index, int *new_token_index) {
     Token *fn_token = &pc->tokens->at(token_index);
     token_index += 1;
     ast_expect_token(pc, fn_token, TokenIdKeywordFn);
 
-    AstNode *node = ast_create_node(NodeTypeFnDecl, fn_token);
+    AstNode *node = ast_create_node(NodeTypeFnProto, fn_token);
 
 
     Token *fn_name = &pc->tokens->at(token_index);
     token_index += 1;
     ast_expect_token(pc, fn_name, TokenIdSymbol);
 
-    ast_buf_from_token(pc, fn_name, &node->data.fn_decl.name);
+    ast_buf_from_token(pc, fn_name, &node->data.fn_proto.name);
 
 
-    ast_parse_param_decl_list(pc, token_index, &token_index, &node->data.fn_decl.params);
+    ast_parse_param_decl_list(pc, token_index, &token_index, &node->data.fn_proto.params);
 
     Token *arrow = &pc->tokens->at(token_index);
     token_index += 1;
     if (arrow->id == TokenIdArrow) {
-        node->data.fn_decl.return_type = ast_parse_type(pc, token_index, &token_index);
+        node->data.fn_proto.return_type = ast_parse_type(pc, token_index, &token_index);
     } else if (arrow->id == TokenIdLBrace) {
-        node->data.fn_decl.return_type = nullptr;
+        node->data.fn_proto.return_type = nullptr;
     } else {
         ast_invalid_token_error(pc, arrow);
     }
 
-    node->data.fn_decl.body = ast_parse_block(pc, token_index, &token_index);
+    *new_token_index = token_index;
+    return node;
+}
+
+/* Parses a function definition: a prototype followed by its body block.
+FnDef : FnProto Block
+*/
+static AstNode *ast_parse_fn_def(ParseContext *pc, int token_index, int *new_token_index) {
+    AstNode *fn_proto = ast_parse_fn_proto(pc, token_index, &token_index);
+    AstNode *node = ast_create_node_with_node(NodeTypeFnDef, fn_proto); // takes line/column from the prototype
+    // NOTE(review): with no '->', ast_parse_fn_proto has already consumed the '{', so the block parse starts after it — confirm.
+    node->data.fn_def.fn_proto = fn_proto;
+    node->data.fn_def.body = ast_parse_block(pc, token_index, &token_index);
+
+    *new_token_index = token_index; // tell the caller where parsing stopped
+    return node;
+}
+
+/* Parses a bodiless function declaration: a prototype terminated by a semicolon.
+FnDecl : FnProto token(Semicolon)
+*/
+static AstNode *ast_parse_fn_decl(ParseContext *pc, int token_index, int *new_token_index) {
+    AstNode *fn_proto = ast_parse_fn_proto(pc, token_index, &token_index);
+    AstNode *node = ast_create_node_with_node(NodeTypeFnDecl, fn_proto); // takes line/column from the prototype
+    // NOTE(review): ast_parse_fn_proto accepts only '->' or '{' after the parameter list, so "fn f();" is rejected before this point.
+    node->data.fn_decl.fn_proto = fn_proto;
+
+    Token *semicolon = &pc->tokens->at(token_index);
+    token_index += 1;
+    ast_expect_token(pc, semicolon, TokenIdSemicolon);
 
     *new_token_index = token_index;
     return node;
 }
 
+/* Parses an extern block: the extern keyword, then a brace-enclosed run of function declarations.
+ExternBlock : token(Extern) token(LBrace) many(FnDecl) token(RBrace)
+*/
+static AstNode *ast_parse_extern_block(ParseContext *pc, int token_index, int *new_token_index) {
+    Token *extern_kw = &pc->tokens->at(token_index);
+    token_index += 1;
+    ast_expect_token(pc, extern_kw, TokenIdKeywordExtern);
+
+    AstNode *node = ast_create_node(NodeTypeExternBlock, extern_kw);
+
+    Token *l_brace = &pc->tokens->at(token_index);
+    token_index += 1;
+    ast_expect_token(pc, l_brace, TokenIdLBrace);
+
+    for (;;) { // one FnDecl per iteration until the closing brace
+        Token *token = &pc->tokens->at(token_index);
+        if (token->id == TokenIdRBrace) {
+            token_index += 1; // consume the '}' before reporting the new position
+            *new_token_index = token_index;
+            return node;
+        } else {
+            AstNode *child = ast_parse_fn_decl(pc, token_index, &token_index);
+            node->data.extern_block.fn_decls.append(child);
+        }
+    }
+
+    // The loop above only exits through its return statement.
+    zig_unreachable();
+}
 
-static void ast_parse_fn_decl_list(ParseContext *pc, int token_index, ZigList<AstNode *> *fn_decls,
-        int *new_token_index)
+static void ast_parse_top_level_decls(ParseContext *pc, int token_index, int *new_token_index,
+        ZigList<AstNode *> *top_level_decls)
 {
     for (;;) {
         Token *token = &pc->tokens->at(token_index);
         if (token->id == TokenIdKeywordFn) {
-            AstNode *fn_decl_node = ast_parse_fn_decl(pc, token_index, &token_index);
-            fn_decls->append(fn_decl_node);
+            AstNode *fn_decl_node = ast_parse_fn_def(pc, token_index, &token_index);
+            top_level_decls->append(fn_decl_node);
+        } else if (token->id == TokenIdKeywordExtern) {
+            AstNode *extern_node = ast_parse_extern_block(pc, token_index, &token_index);
+            top_level_decls->append(extern_node);
         } else {
             *new_token_index = token_index;
             return;
@@ -440,7 +568,7 @@ AstNode *ast_parse(Buf *buf, ZigList<Token> *tokens) {
     pc.tokens = tokens;
 
     int new_token_index;
-    ast_parse_fn_decl_list(&pc, 0, &pc.root->data.root.fn_decls, &new_token_index);
+    ast_parse_top_level_decls(&pc, 0, &new_token_index, &pc.root->data.root.top_level_decls);
 
     if (new_token_index != tokens->length - 1) {
         ast_invalid_token_error(&pc, &tokens->at(new_token_index));
src/parser.hpp
@@ -10,6 +10,8 @@ struct CodeGenNode;
 
 enum NodeType {
     NodeTypeRoot,
+    NodeTypeFnProto,
+    NodeTypeFnDef,
     NodeTypeFnDecl,
     NodeTypeParamDecl,
     NodeTypeType,
@@ -17,19 +19,28 @@ enum NodeType {
     NodeTypeStatement,
     NodeTypeExpression,
     NodeTypeFnCall,
+    NodeTypeExternBlock,
 };
 
 struct AstNodeRoot {
-    ZigList<AstNode *> fn_decls;
+    ZigList<AstNode *> top_level_decls;
 };
 
-struct AstNodeFnDecl {
+struct AstNodeFnProto {
     Buf name;
     ZigList<AstNode *> params;
     AstNode *return_type;
+};
+
+struct AstNodeFnDef { // payload of a NodeTypeFnDef node: prototype plus body
+    AstNode *fn_proto; // NodeTypeFnProto child holding name, params and return type
     AstNode *body;
 };
 
+struct AstNodeFnDecl { // payload of a NodeTypeFnDecl node: a prototype with no body
+    AstNode *fn_proto; // NodeTypeFnProto child produced by ast_parse_fn_proto
+};
+
 struct AstNodeParamDecl {
     Buf name;
     AstNode *type;
@@ -92,6 +103,10 @@ struct AstNodeFnCall {
     ZigList<AstNode *> params;
 };
 
+struct AstNodeExternBlock { // payload of a NodeTypeExternBlock node
+    ZigList<AstNode *> fn_decls; // NodeTypeFnDecl children appended by ast_parse_extern_block
+};
+
 struct AstNode {
     enum NodeType type;
     AstNode *parent;
@@ -100,13 +115,16 @@ struct AstNode {
     CodeGenNode *codegen_node;
     union {
         AstNodeRoot root;
+        AstNodeFnDef fn_def;
         AstNodeFnDecl fn_decl;
+        AstNodeFnProto fn_proto;
         AstNodeType type;
         AstNodeParamDecl param_decl;
         AstNodeBlock block;
         AstNodeStatement statement;
         AstNodeExpression expression;
         AstNodeFnCall fn_call;
+        AstNodeExternBlock extern_block;
     } data;
 };
 
src/tokenizer.cpp
@@ -150,6 +150,8 @@ static void end_token(Tokenize *t) {
         t->cur_tok->id = TokenIdKeywordMut;
     } else if (mem_eql_str(token_mem, token_len, "const")) {
         t->cur_tok->id = TokenIdKeywordConst;
+    } else if (mem_eql_str(token_mem, token_len, "extern")) {
+        t->cur_tok->id = TokenIdKeywordExtern;
     }
 
     t->cur_tok = nullptr;
@@ -307,6 +309,7 @@ static const char * token_name(Token *token) {
         case TokenIdKeywordConst: return "Const";
         case TokenIdKeywordMut: return "Mut";
         case TokenIdKeywordReturn: return "Return";
+        case TokenIdKeywordExtern: return "Extern";
         case TokenIdLParen: return "LParen";
         case TokenIdRParen: return "RParen";
         case TokenIdComma: return "Comma";
src/tokenizer.hpp
@@ -17,6 +17,7 @@ enum TokenId {
     TokenIdKeywordReturn,
     TokenIdKeywordMut,
     TokenIdKeywordConst,
+    TokenIdKeywordExtern,
     TokenIdLParen,
     TokenIdRParen,
     TokenIdComma,
test/hello.zig
@@ -1,3 +1,7 @@
+extern {
+    fn puts(s: *mut u8) -> i32;
+}
+
 fn main(argc: i32, argv: *mut *mut u8) -> i32 {
     puts("Hello, world!\n");
     return 0;
README.md
@@ -72,27 +72,35 @@ zig    | C equivalent | Description
 ### Grammar
 
 ```
-Root : many(FnDecl) token(EOF);
+Root : many(TopLevelDecl) token(EOF)
 
-FnDecl : token(Fn) token(Symbol) ParamDeclList option(token(Arrow) Type) Block;
+TopLevelDecl : FnDef | ExternBlock
 
-ParamDeclList : token(LParen) list(ParamDecl, token(Comma)) token(RParen);
+ExternBlock : token(Extern) token(LBrace) many(FnDecl) token(RBrace)
 
-ParamDecl : token(Symbol) token(Colon) Type;
+FnProto : token(Fn) token(Symbol) ParamDeclList option(token(Arrow) Type)
 
-Type : token(Symbol) | PointerType;
+FnDecl : FnProto token(Semicolon)
 
-PointerType : token(Star) token(Const) Type  | token(Star) token(Mut) Type;
+FnDef : FnProto Block
 
-Block : token(LBrace) many(Statement) token(RBrace);
+ParamDeclList : token(LParen) list(ParamDecl, token(Comma)) token(RParen)
 
-Statement : ExpressionStatement  | ReturnStatement ;
+ParamDecl : token(Symbol) token(Colon) Type
 
-ExpressionStatement : Expression token(Semicolon) ;
+Type : token(Symbol) | PointerType
 
-ReturnStatement : token(Return) Expression token(Semicolon) ;
+PointerType : token(Star) token(Const) Type  | token(Star) token(Mut) Type
 
-Expression : token(Number)  | token(String)  | FnCall ;
+Block : token(LBrace) many(Statement) token(RBrace)
 
-FnCall : token(Symbol) token(LParen) list(Expression, token(Comma)) token(RParen) ;
+Statement : ExpressionStatement  | ReturnStatement
+
+ExpressionStatement : Expression token(Semicolon)
+
+ReturnStatement : token(Return) Expression token(Semicolon)
+
+Expression : token(Number) | token(String) | FnCall
+
+FnCall : token(Symbol) token(LParen) list(Expression, token(Comma)) token(RParen)
 ```