Commit a73453a268

Andrew Kelley <superjoe30@gmail.com>
2016-01-27 00:00:39
add c_import top level decl
see #88
1 parent 5afe473
src/all_types.hpp
@@ -136,7 +136,8 @@ enum NodeType {
     NodeTypeArrayAccessExpr,
     NodeTypeSliceExpr,
     NodeTypeFieldAccessExpr,
-    NodeTypeUse,
+    NodeTypeImport,
+    NodeTypeCImport,
     NodeTypeBoolLiteral,
     NodeTypeNullLiteral,
     NodeTypeUndefinedLiteral,
@@ -410,7 +411,7 @@ struct AstNodePrefixOpExpr {
     Expr resolved_expr;
 };
 
-struct AstNodeUse {
+struct AstNodeImport {
     Buf path;
     ZigList<AstNode *> *directives;
     VisibMod visib_mod;
@@ -419,6 +420,15 @@ struct AstNodeUse {
     ImportTableEntry *import;
 };
 
+struct AstNodeCImport {
+    ZigList<AstNode *> *directives;
+    VisibMod visib_mod;
+    AstNode *block;
+
+    // populated by semantic analyzer
+    TopLevelDecl top_level_decl;
+};
+
 struct AstNodeIfBoolExpr {
     AstNode *condition;
     AstNode *then_block;
@@ -699,7 +709,8 @@ struct AstNode {
         AstNodeFnCallExpr fn_call_expr;
         AstNodeArrayAccessExpr array_access_expr;
         AstNodeSliceExpr slice_expr;
-        AstNodeUse use;
+        AstNodeImport import;
+        AstNodeCImport c_import;
         AstNodeIfBoolExpr if_bool_expr;
         AstNodeIfVarExpr if_var_expr;
         AstNodeWhileExpr while_expr;
@@ -920,6 +931,9 @@ enum BuiltinFnId {
     BuiltinFnIdAddWithOverflow,
     BuiltinFnIdSubWithOverflow,
     BuiltinFnIdMulWithOverflow,
+    BuiltinFnIdCInclude,
+    BuiltinFnIdCDefine,
+    BuiltinFnIdCUndef,
 };
 
 struct BuiltinFnEntry {
@@ -1051,6 +1065,7 @@ struct BlockContext {
     ZigList<VariableTableEntry *> variable_list;
     AstNode *parent_loop_node;
     LLVMZigDIScope *di_scope;
+    Buf *c_import_buf;
 };
 
 #endif
src/analyze.cpp
@@ -21,6 +21,8 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
         AstNode *node);
 static TypeTableEntry *analyze_error_literal_expr(CodeGen *g, ImportTableEntry *import,
         BlockContext *context, AstNode *node, Buf *err_name);
+static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node);
 
 static AstNode *first_executing_node(AstNode *node) {
     switch (node->type) {
@@ -54,7 +56,8 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeCharLiteral:
         case NodeTypeSymbol:
         case NodeTypePrefixOpExpr:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeBoolLiteral:
         case NodeTypeNullLiteral:
         case NodeTypeUndefinedLiteral:
@@ -1018,6 +1021,25 @@ static void resolve_error_value_decl(CodeGen *g, ImportTableEntry *import, AstNo
     }
 }
 
+static void resolve_c_import_decl(CodeGen *g, ImportTableEntry *import, AstNode *node) {
+    assert(node->type == NodeTypeCImport);
+
+    AstNode *block_node = node->data.c_import.block;
+
+    BlockContext *child_context = new_block_context(node, import->block_context);
+    child_context->c_import_buf = buf_alloc();
+
+    TypeTableEntry *resolved_type = analyze_block_expr(g, import, child_context,
+            g->builtin_types.entry_void, block_node);
+
+    if (resolved_type->id == TypeTableEntryIdInvalid) {
+        return;
+    }
+
+    fprintf(stderr, "c import buf:\n%s\n", buf_ptr(child_context->c_import_buf));
+    zig_panic("TODO");
+}
+
 static void resolve_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode *node) {
     switch (node->type) {
         case NodeTypeFnProto:
@@ -1053,9 +1075,12 @@ static void resolve_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeErrorValueDecl:
             resolve_error_value_decl(g, import, node);
             break;
-        case NodeTypeUse:
+        case NodeTypeImport:
             // nothing to do here
             break;
+        case NodeTypeCImport:
+            resolve_c_import_decl(g, import, node);
+            break;
         case NodeTypeFnDef:
         case NodeTypeDirective:
         case NodeTypeParamDecl:
@@ -1501,6 +1526,7 @@ BlockContext *new_block_context(AstNode *node, BlockContext *parent) {
 
     if (parent) {
         context->parent_loop_node = parent->parent_loop_node;
+        context->c_import_buf = parent->c_import_buf;
     }
 
     if (node && node->type == NodeTypeFnDef) {
@@ -1997,7 +2023,7 @@ static TypeTableEntry *resolve_expr_const_val_as_float_num_lit(CodeGen *g, AstNo
     }
 }
 
-static TypeTableEntry *resolve_expr_const_val_as_bignum_op(CodeGen *g, AstNode *node, 
+static TypeTableEntry *resolve_expr_const_val_as_bignum_op(CodeGen *g, AstNode *node,
         bool (*bignum_fn)(BigNum *, BigNum *, BigNum *), AstNode *op1, AstNode *op2,
         TypeTableEntry *resolved_type)
 {
@@ -2521,7 +2547,7 @@ static VariableTableEntry *add_local_var(CodeGen *g, AstNode *source_node, Block
     if (name) {
         buf_init_from_buf(&variable_entry->name, name);
         VariableTableEntry *existing_var;
-        
+
         if (context->fn_entry) {
             existing_var = find_local_variable(context, name);
         } else {
@@ -3396,6 +3422,47 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
                     return resolve_expr_const_val_as_type(g, node, type_entry);
                 }
             }
+        case BuiltinFnIdCInclude:
+            {
+                if (!context->c_import_buf) {
+                    add_node_error(g, node, buf_sprintf("@c_include valid only in c_import blocks"));
+                    return g->builtin_types.entry_invalid;
+                }
+
+                AstNode **str_node = node->data.fn_call_expr.params.at(0)->parent_field;
+                TypeTableEntry *str_type = get_unknown_size_array_type(g, g->builtin_types.entry_u8, true);
+                TypeTableEntry *resolved_type = analyze_expression(g, import, context, str_type, *str_node);
+
+                if (resolved_type->id == TypeTableEntryIdInvalid) {
+                    return resolved_type;
+                }
+
+                ConstExprValue *const_str_val = &get_resolved_expr(*str_node)->const_val;
+
+                if (!const_str_val->ok) {
+                    add_node_error(g, *str_node, buf_sprintf("@c_include requires constant expression"));
+                    return g->builtin_types.entry_void;
+                }
+
+                buf_appendf(context->c_import_buf, "#include \"");
+                ConstExprValue *ptr_field = const_str_val->data.x_struct.fields[0];
+                uint64_t len = ptr_field->data.x_ptr.len;
+                for (uint64_t i = 0; i < len; i += 1) {
+                    ConstExprValue *char_val = ptr_field->data.x_ptr.ptr[i];
+                    uint64_t big_c = char_val->data.x_bignum.data.x_uint;
+                    assert(big_c <= UINT8_MAX);
+                    uint8_t c = big_c;
+                    buf_appendf(context->c_import_buf, "%c", c);
+                }
+                buf_appendf(context->c_import_buf, "\"\n");
+
+                return g->builtin_types.entry_void;
+            }
+        case BuiltinFnIdCDefine:
+            zig_panic("TODO");
+        case BuiltinFnIdCUndef:
+            zig_panic("TODO");
+
     }
     zig_unreachable();
 }
@@ -4038,7 +4105,8 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnDef:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeLabel:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
@@ -4140,7 +4208,8 @@ static void analyze_top_level_decl(CodeGen *g, ImportTableEntry *import, AstNode
                 break;
             }
         case NodeTypeRootExportDecl:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeVariableDeclaration:
         case NodeTypeErrorValueDecl:
         case NodeTypeFnProto:
@@ -4340,7 +4409,8 @@ static void collect_expr_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode
         case NodeTypeFnDecl:
         case NodeTypeParamDecl:
         case NodeTypeDirective:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeLabel:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
@@ -4486,9 +4556,24 @@ static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, Ast
         case NodeTypeRootExportDecl:
             resolve_top_level_decl(g, import, node);
             break;
-        case NodeTypeUse:
+        case NodeTypeImport:
             // already taken care of
             break;
+        case NodeTypeCImport:
+            {
+                TopLevelDecl *decl_node = &node->data.c_import.top_level_decl;
+                decl_node->deps.init(1);
+                collect_expr_decl_deps(g, import, node->data.c_import.block, decl_node);
+
+                decl_node->name = buf_sprintf("c_import_%" PRIu32, node->create_index);
+                decl_node->import = import;
+                if (decl_node->deps.size() > 0) {
+                    g->unresolved_top_level_decls.put(decl_node->name, node);
+                } else {
+                    resolve_top_level_decl(g, import, node);
+                }
+                break;
+            }
         case NodeTypeErrorValueDecl:
             // error value declarations do not depend on other top level decls
             resolve_top_level_decl(g, import, node);
@@ -4619,15 +4704,15 @@ void semantic_analyze(CodeGen *g) {
 
             for (int i = 0; i < import->root->data.root.top_level_decls.length; i += 1) {
                 AstNode *child = import->root->data.root.top_level_decls.at(i);
-                if (child->type == NodeTypeUse) {
-                    for (int i = 0; i < child->data.use.directives->length; i += 1) {
-                        AstNode *directive_node = child->data.use.directives->at(i);
+                if (child->type == NodeTypeImport) {
+                    for (int i = 0; i < child->data.import.directives->length; i += 1) {
+                        AstNode *directive_node = child->data.import.directives->at(i);
                         Buf *name = &directive_node->data.directive.name;
                         add_node_error(g, directive_node,
                                 buf_sprintf("invalid directive: '%s'", buf_ptr(name)));
                     }
 
-                    ImportTableEntry *target_import = child->data.use.import;
+                    ImportTableEntry *target_import = child->data.import.import;
                     assert(target_import);
 
                     target_import->importers.append({import, child});
@@ -4760,7 +4845,8 @@ Expr *get_resolved_expr(AstNode *node) {
         case NodeTypeFnDecl:
         case NodeTypeParamDecl:
         case NodeTypeDirective:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
@@ -4780,6 +4866,8 @@ TopLevelDecl *get_resolved_top_level_decl(AstNode *node) {
             return &node->data.struct_decl.top_level_decl;
         case NodeTypeErrorValueDecl:
             return &node->data.error_value_decl.top_level_decl;
+        case NodeTypeCImport:
+            return &node->data.c_import.top_level_decl;
         case NodeTypeNumberLiteral:
         case NodeTypeReturnExpr:
         case NodeTypeBinOpExpr:
@@ -4808,7 +4896,7 @@ TopLevelDecl *get_resolved_top_level_decl(AstNode *node) {
         case NodeTypeStringLiteral:
         case NodeTypeCharLiteral:
         case NodeTypeSymbol:
-        case NodeTypeUse:
+        case NodeTypeImport:
         case NodeTypeBoolLiteral:
         case NodeTypeNullLiteral:
         case NodeTypeUndefinedLiteral:
src/codegen.cpp
@@ -177,6 +177,9 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
     switch (builtin_fn->id) {
         case BuiltinFnIdInvalid:
         case BuiltinFnIdTypeof:
+        case BuiltinFnIdCInclude:
+        case BuiltinFnIdCDefine:
+        case BuiltinFnIdCUndef:
             zig_unreachable();
         case BuiltinFnIdAddWithOverflow:
         case BuiltinFnIdSubWithOverflow:
@@ -2250,7 +2253,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeFnDecl:
         case NodeTypeParamDecl:
         case NodeTypeDirective:
-        case NodeTypeUse:
+        case NodeTypeImport:
+        case NodeTypeCImport:
         case NodeTypeStructDecl:
         case NodeTypeStructField:
         case NodeTypeStructValueField:
@@ -2930,6 +2934,9 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn_with_arg_count(g, BuiltinFnIdAddWithOverflow, "add_with_overflow", 4);
     create_builtin_fn_with_arg_count(g, BuiltinFnIdSubWithOverflow, "sub_with_overflow", 4);
     create_builtin_fn_with_arg_count(g, BuiltinFnIdMulWithOverflow, "mul_with_overflow", 4);
+    create_builtin_fn_with_arg_count(g, BuiltinFnIdCInclude, "c_include", 1);
+    create_builtin_fn_with_arg_count(g, BuiltinFnIdCDefine, "c_define", 2);
+    create_builtin_fn_with_arg_count(g, BuiltinFnIdCUndef, "c_undef", 1);
 }
 
 
@@ -3138,8 +3145,8 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
                     }
                 }
             }
-        } else if (top_level_decl->type == NodeTypeUse) {
-            Buf *import_target_path = &top_level_decl->data.use.path;
+        } else if (top_level_decl->type == NodeTypeImport) {
+            Buf *import_target_path = &top_level_decl->data.import.path;
             Buf full_path = BUF_INIT;
             Buf *import_code = buf_alloc();
             bool found_it = false;
@@ -3163,7 +3170,7 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
                 auto entry = g->import_table.maybe_get(abs_full_path);
                 if (entry) {
                     found_it = true;
-                    top_level_decl->data.use.import = entry->value;
+                    top_level_decl->data.import.import = entry->value;
                 } else {
                     if ((err = os_fetch_file_path(abs_full_path, import_code))) {
                         if (err == ErrorFileNotFound) {
@@ -3175,8 +3182,8 @@ static ImportTableEntry *codegen_add_code(CodeGen *g, Buf *abs_full_path,
                             goto done_looking_at_imports;
                         }
                     }
-                    top_level_decl->data.use.import = codegen_add_code(g,
-                            abs_full_path, search_path, &top_level_decl->data.use.path, import_code);
+                    top_level_decl->data.import.import = codegen_add_code(g,
+                            abs_full_path, search_path, &top_level_decl->data.import.path, import_code);
                     found_it = true;
                 }
                 break;
src/parser.cpp
@@ -123,8 +123,10 @@ const char *node_type_str(NodeType node_type) {
             return "Symbol";
         case NodeTypePrefixOpExpr:
             return "PrefixOpExpr";
-        case NodeTypeUse:
-            return "Use";
+        case NodeTypeImport:
+            return "Import";
+        case NodeTypeCImport:
+            return "CImport";
         case NodeTypeBoolLiteral:
             return "BoolLiteral";
         case NodeTypeNullLiteral:
@@ -329,8 +331,12 @@ void ast_print(AstNode *node, int indent) {
         case NodeTypeSymbol:
             fprintf(stderr, "Symbol %s\n", buf_ptr(&node->data.symbol_expr.symbol));
             break;
-        case NodeTypeUse:
-            fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.use.path));
+        case NodeTypeImport:
+            fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.import.path));
+            break;
+        case NodeTypeCImport:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.c_import.block, indent + 2);
             break;
         case NodeTypeBoolLiteral:
             fprintf(stderr, "%s '%s'\n", node_type_str(node->type),
@@ -2768,19 +2774,35 @@ static AstNode *ast_parse_import(ParseContext *pc, int *token_index,
         return nullptr;
     *token_index += 1;
 
-    Token *use_name = &pc->tokens->at(*token_index);
-    *token_index += 1;
-    ast_expect_token(pc, use_name, TokenIdStringLiteral);
+    Token *import_name = ast_eat_token(pc, token_index, TokenIdStringLiteral);
 
-    Token *semicolon = &pc->tokens->at(*token_index);
+    ast_eat_token(pc, token_index, TokenIdSemicolon);
+
+    AstNode *node = ast_create_node(pc, NodeTypeImport, import_kw);
+    node->data.import.visib_mod = visib_mod;
+    node->data.import.directives = directives;
+
+    parse_string_literal(pc, import_name, &node->data.import.path, nullptr, nullptr);
+    normalize_parent_ptrs(node);
+    return node;
+}
+
+/*
+CImportDecl : "c_import" Block
+*/
+static AstNode *ast_parse_c_import(ParseContext *pc, int *token_index,
+        ZigList<AstNode*> *directives, VisibMod visib_mod)
+{
+    Token *c_import_kw = &pc->tokens->at(*token_index);
+    if (c_import_kw->id != TokenIdKeywordCImport)
+        return nullptr;
     *token_index += 1;
-    ast_expect_token(pc, semicolon, TokenIdSemicolon);
 
-    AstNode *node = ast_create_node(pc, NodeTypeUse, import_kw);
-    node->data.use.visib_mod = visib_mod;
-    node->data.use.directives = directives;
+    AstNode *node = ast_create_node(pc, NodeTypeCImport, c_import_kw);
+    node->data.c_import.visib_mod = visib_mod;
+    node->data.c_import.directives = directives;
+    node->data.c_import.block = ast_parse_block(pc, token_index, true);
 
-    parse_string_literal(pc, use_name, &node->data.use.path, nullptr, nullptr);
     normalize_parent_ptrs(node);
     return node;
 }
@@ -2953,6 +2975,12 @@ static void ast_parse_top_level_decls(ParseContext *pc, int *token_index, ZigLis
             continue;
         }
 
+        AstNode *c_import_node = ast_parse_c_import(pc, token_index, directives, visib_mod);
+        if (c_import_node) {
+            top_level_decls->append(c_import_node);
+            continue;
+        }
+
         AstNode *struct_node = ast_parse_struct_decl(pc, token_index, directives, visib_mod);
         if (struct_node) {
             top_level_decls->append(struct_node);
@@ -3103,8 +3131,12 @@ void normalize_parent_ptrs(AstNode *node) {
         case NodeTypeFieldAccessExpr:
             set_field(&node->data.field_access_expr.struct_expr);
             break;
-        case NodeTypeUse:
-            set_list_fields(node->data.use.directives);
+        case NodeTypeImport:
+            set_list_fields(node->data.import.directives);
+            break;
+        case NodeTypeCImport:
+            set_list_fields(node->data.c_import.directives);
+            set_field(&node->data.c_import.block);
             break;
         case NodeTypeBoolLiteral:
             // none
src/tokenizer.cpp
@@ -211,6 +211,8 @@ static void end_token(Tokenize *t) {
         t->cur_tok->id = TokenIdKeywordPub;
     } else if (mem_eql_str(token_mem, token_len, "export")) {
         t->cur_tok->id = TokenIdKeywordExport;
+    } else if (mem_eql_str(token_mem, token_len, "c_import")) {
+        t->cur_tok->id = TokenIdKeywordCImport;
     } else if (mem_eql_str(token_mem, token_len, "import")) {
         t->cur_tok->id = TokenIdKeywordImport;
     } else if (mem_eql_str(token_mem, token_len, "true")) {
@@ -1041,6 +1043,7 @@ const char * token_name(TokenId id) {
         case TokenIdKeywordPub: return "pub";
         case TokenIdKeywordExport: return "export";
         case TokenIdKeywordImport: return "import";
+        case TokenIdKeywordCImport: return "c_import";
         case TokenIdKeywordTrue: return "true";
         case TokenIdKeywordFalse: return "false";
         case TokenIdKeywordIf: return "if";
src/tokenizer.hpp
@@ -21,6 +21,7 @@ enum TokenId {
     TokenIdKeywordPub,
     TokenIdKeywordExport,
     TokenIdKeywordImport,
+    TokenIdKeywordCImport,
     TokenIdKeywordTrue,
     TokenIdKeywordFalse,
     TokenIdKeywordIf,