Commit 3a326d5005

Andrew Kelley <superjoe30@gmail.com>
2016-01-18 16:50:10
pave the road for function pointers
See #14
1 parent 4c50606
src/all_types.hpp
@@ -793,11 +793,6 @@ struct LabelTableEntry {
     bool entered_from_fallthrough;
 };
 
-enum FnAttrId {
-    FnAttrIdNaked,
-    FnAttrIdAlwaysInline,
-};
-
 struct FnTableEntry {
     LLVMValueRef fn_value;
     AstNode *proto_node;
@@ -806,7 +801,8 @@ struct FnTableEntry {
     bool internal_linkage;
     unsigned calling_convention;
     ImportTableEntry *import_entry;
-    ZigList<FnAttrId> fn_attr_list;
+    bool is_naked;
+    bool is_inline;
     // Required to be a pre-order traversal of the AST. (parents must come before children)
     ZigList<BlockContext *> all_block_contexts;
     TypeTableEntry *member_of_struct;
src/analyze.cpp
@@ -346,18 +346,19 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
         ImportTableEntry *import)
 {
     assert(node->type == NodeTypeFnProto);
+    AstNodeFnProto *fn_proto = &node->data.fn_proto;
 
-    for (int i = 0; i < node->data.fn_proto.directives->length; i += 1) {
-        AstNode *directive_node = node->data.fn_proto.directives->at(i);
+    for (int i = 0; i < fn_proto->directives->length; i += 1) {
+        AstNode *directive_node = fn_proto->directives->at(i);
         Buf *name = &directive_node->data.directive.name;
 
         if (buf_eql_str(name, "attribute")) {
             Buf *attr_name = &directive_node->data.directive.param;
             if (fn_table_entry->fn_def_node) {
                 if (buf_eql_str(attr_name, "naked")) {
-                    fn_table_entry->fn_attr_list.append(FnAttrIdNaked);
+                    fn_table_entry->is_naked = true;
                 } else if (buf_eql_str(attr_name, "inline")) {
-                    fn_table_entry->fn_attr_list.append(FnAttrIdAlwaysInline);
+                    fn_table_entry->is_inline = true;
                 } else {
                     add_node_error(g, directive_node,
                             buf_sprintf("invalid function attribute: '%s'", buf_ptr(name)));
@@ -382,38 +383,100 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t
 
     buf_resize(&fn_type->name, 0);
     buf_appendf(&fn_type->name, "fn(");
+    int gen_param_count = 0;
+    LLVMTypeRef *gen_param_types = allocate<LLVMTypeRef>(param_count);
+    LLVMZigDIType **param_di_types = allocate<LLVMZigDIType*>(1 + param_count);
     for (int i = 0; i < param_count; i += 1) {
         AstNode *child = node->data.fn_proto.params.at(i);
         assert(child->type == NodeTypeParamDecl);
         TypeTableEntry *type_entry = analyze_type_expr(g, import, import->block_context,
                 child->data.param_decl.type);
-        fn_table_entry->type_entry->data.fn.param_types[i] = type_entry;
-
-        buf_appendf(&fn_type->name, "%s", buf_ptr(&type_entry->name));
-
-        if (i + 1 < param_count) {
-            buf_appendf(&fn_type->name, ", ");
-        }
+        fn_type->data.fn.param_types[i] = type_entry;
 
         if (type_entry->id == TypeTableEntryIdUnreachable) {
             add_node_error(g, child->data.param_decl.type,
                 buf_sprintf("parameter of type 'unreachable' not allowed"));
-        } else if (type_entry->id == TypeTableEntryIdVoid) {
-            if (node->data.fn_proto.visib_mod == VisibModExport) {
-                add_node_error(g, child->data.param_decl.type,
-                    buf_sprintf("parameter of type 'void' not allowed on exported functions"));
-            }
+            fn_proto->skip = true;
+        } else if (type_entry->id == TypeTableEntryIdInvalid) {
+            fn_proto->skip = true;
+        }
+
+        if (!fn_proto->skip && type_entry->size_in_bits > 0) {
+            const char *comma = (gen_param_count == 0) ? "" : ", ";
+            buf_appendf(&fn_type->name, "%s%s", comma, buf_ptr(&type_entry->name));
+
+            TypeTableEntry *gen_type = handle_is_ptr(type_entry) ?
+                get_pointer_to_type(g, type_entry, true) : type_entry;
+            gen_param_types[gen_param_count] = gen_type->type_ref;
+
+            gen_param_count += 1;
+
+            // after the gen_param_count += 1 because 0 is the return type
+            param_di_types[gen_param_count] = gen_type->di_type;
         }
     }
 
     TypeTableEntry *return_type = analyze_type_expr(g, import, import->block_context,
             node->data.fn_proto.return_type);
-    fn_table_entry->type_entry->data.fn.return_type = return_type;
+    fn_type->data.fn.return_type = return_type;
+    if (return_type->id == TypeTableEntryIdInvalid) {
+        fn_proto->skip = true;
+    }
 
     buf_appendf(&fn_type->name, ")");
     if (return_type->id != TypeTableEntryIdVoid) {
         buf_appendf(&fn_type->name, " %s", buf_ptr(&return_type->name));
     }
+
+    if (fn_proto->skip) {
+        return;
+    }
+
+    fn_type->type_ref = LLVMFunctionType(return_type->type_ref, gen_param_types, gen_param_count,
+            fn_proto->is_var_args);
+
+    fn_table_entry->fn_value = LLVMAddFunction(g->module, buf_ptr(&fn_table_entry->symbol_name),
+            fn_table_entry->type_entry->type_ref);
+
+    if (fn_table_entry->is_inline) {
+        LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMAlwaysInlineAttribute);
+    }
+    if (fn_table_entry->is_naked) {
+        LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNakedAttribute);
+    }
+
+    LLVMSetLinkage(fn_table_entry->fn_value, fn_table_entry->internal_linkage ?
+            LLVMInternalLinkage : LLVMExternalLinkage);
+
+    if (return_type->id == TypeTableEntryIdUnreachable) {
+        LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNoReturnAttribute);
+    }
+    LLVMSetFunctionCallConv(fn_table_entry->fn_value, fn_table_entry->calling_convention);
+    if (!fn_table_entry->is_extern) {
+        LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMNoUnwindAttribute);
+    }
+
+    param_di_types[0] = return_type->di_type;
+    LLVMZigDISubroutineType *di_sub_type = LLVMZigCreateSubroutineType(g->dbuilder, import->di_file,
+            param_di_types, gen_param_count + 1, 0);
+
+    // Add debug info.
+    unsigned line_number = node->line + 1;
+    unsigned scope_line = line_number;
+    bool is_definition = fn_table_entry->fn_def_node != nullptr;
+    unsigned flags = 0;
+    bool is_optimized = g->build_type == CodeGenBuildTypeRelease;
+    LLVMZigDISubprogram *subprogram = LLVMZigCreateFunction(g->dbuilder,
+        import->block_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "",
+        import->di_file, line_number,
+        di_sub_type, fn_table_entry->internal_linkage, 
+        is_definition, scope_line, flags, is_optimized, fn_table_entry->fn_value);
+    fn_type->di_type = LLVMZigSubroutineToType(di_sub_type);
+    if (fn_table_entry->fn_def_node) {
+        BlockContext *context = new_block_context(fn_table_entry->fn_def_node, import->block_context);
+        fn_table_entry->fn_def_node->data.fn_def.block_context = context;
+        context->di_scope = LLVMZigSubprogramToScope(subprogram);
+    }
 }
 
 static void preview_function_labels(CodeGen *g, AstNode *node, FnTableEntry *fn_table_entry) {
@@ -724,14 +787,6 @@ static void preview_fn_proto(CodeGen *g, ImportTableEntry *import,
                 buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
         proto_node->data.fn_proto.skip = true;
         skip = true;
-    } else if (is_pub) {
-        // TODO is this else if branch a mistake?
-        auto entry = fn_table->maybe_get(proto_name);
-        if (entry) {
-            add_node_error(g, proto_node, buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
-            proto_node->data.fn_proto.skip = true;
-            skip = true;
-        }
     }
     if (!extern_node && proto_node->data.fn_proto.is_var_args) {
         add_node_error(g, proto_node,
@@ -775,10 +830,8 @@ static void preview_fn_proto(CodeGen *g, ImportTableEntry *import,
         g->bootstrap_import->fn_table.put(proto_name, fn_table_entry);
     }
 
-    resolve_function_proto(g, proto_node, fn_table_entry, import);
-
-
     proto_node->data.fn_proto.fn_table_entry = fn_table_entry;
+    resolve_function_proto(g, proto_node, fn_table_entry, import);
 
     if (fn_def_node) {
         preview_function_labels(g, fn_def_node->data.fn_def.body, fn_table_entry);
@@ -3146,8 +3199,7 @@ static void analyze_top_level_fn_def(CodeGen *g, ImportTableEntry *import, AstNo
         return;
     }
 
-    BlockContext *context = new_block_context(node, import->block_context);
-    node->data.fn_def.block_context = context;
+    BlockContext *context = node->data.fn_def.block_context;
 
     AstNodeFnProto *fn_proto = &fn_proto_node->data.fn_proto;
     bool is_exported = (fn_proto->visib_mod == VisibModExport);
@@ -3923,3 +3975,11 @@ TypeTableEntry **get_int_type_ptr(CodeGen *g, bool is_signed, int size_in_bits)
 TypeTableEntry *get_int_type(CodeGen *g, bool is_signed, int size_in_bits) {
     return *get_int_type_ptr(g, is_signed, size_in_bits);
 }
+
+bool handle_is_ptr(TypeTableEntry *type_entry) {
+    return type_entry->id == TypeTableEntryIdStruct ||
+            (type_entry->id == TypeTableEntryIdEnum && type_entry->data.enumeration.gen_field_count != 0) ||
+            type_entry->id == TypeTableEntryIdMaybe ||
+            type_entry->id == TypeTableEntryIdArray;
+}
+
src/analyze.hpp
@@ -23,5 +23,5 @@ TopLevelDecl *get_resolved_top_level_decl(AstNode *node);
 bool is_node_void_expr(AstNode *node);
 TypeTableEntry **get_int_type_ptr(CodeGen *g, bool is_signed, int size_in_bits);
 TypeTableEntry *get_int_type(CodeGen *g, bool is_signed, int size_in_bits);
-
+bool handle_is_ptr(TypeTableEntry *type_entry);
 #endif
src/codegen.cpp
@@ -82,29 +82,13 @@ static TypeTableEntry *get_type_for_type_node(AstNode *node) {
     return const_val->data.x_type;
 }
 
-static TypeTableEntry *fn_proto_type_from_type_node(CodeGen *g, AstNode *type_node) {
-    TypeTableEntry *type_entry = get_type_for_type_node(type_node);
-
-    if (type_entry->id == TypeTableEntryIdStruct || type_entry->id == TypeTableEntryIdArray) {
-        return get_pointer_to_type(g, type_entry, true);
-    } else {
-        return type_entry;
-    }
-}
-
-static LLVMZigDIType *to_llvm_debug_type(CodeGen *g, AstNode *type_node) {
-    TypeTableEntry *type_entry = get_type_for_type_node(type_node);
-    return type_entry->di_type;
-}
-
-
 static bool type_is_unreachable(CodeGen *g, AstNode *type_node) {
     return get_type_for_type_node(type_node)->id == TypeTableEntryIdUnreachable;
 }
 
 static bool is_param_decl_type_void(CodeGen *g, AstNode *param_decl_node) {
     assert(param_decl_node->type == NodeTypeParamDecl);
-    return get_type_for_type_node(param_decl_node->data.param_decl.type)->id == TypeTableEntryIdVoid;
+    return get_type_for_type_node(param_decl_node->data.param_decl.type)->size_in_bits == 0;
 }
 
 static int count_non_void_params(CodeGen *g, ZigList<AstNode *> *params) {
@@ -150,11 +134,14 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
     return expr->type_entry;
 }
 
-static bool handle_is_ptr(TypeTableEntry *type_entry) {
-    return type_entry->id == TypeTableEntryIdStruct ||
-            (type_entry->id == TypeTableEntryIdEnum && type_entry->data.enumeration.gen_field_count != 0) ||
-            type_entry->id == TypeTableEntryIdMaybe ||
-            type_entry->id == TypeTableEntryIdArray;
+static TypeTableEntry *fn_proto_type_from_type_node(CodeGen *g, AstNode *type_node) {
+    TypeTableEntry *type_entry = get_type_for_type_node(type_node);
+
+    if (handle_is_ptr(type_entry)) {
+        return get_pointer_to_type(g, type_entry, true);
+    } else {
+        return type_entry;
+    }
 }
 
 static LLVMValueRef gen_number_literal_raw(CodeGen *g, AstNode *source_node,
@@ -2121,31 +2108,6 @@ static void build_label_blocks(CodeGen *g, AstNode *block_node) {
 
 }
 
-static LLVMZigDISubroutineType *create_di_function_type(CodeGen *g, AstNodeFnProto *fn_proto,
-        LLVMZigDIFile *di_file)
-{
-    LLVMZigDIType **types = allocate<LLVMZigDIType*>(1 + fn_proto->params.length);
-    types[0] = to_llvm_debug_type(g, fn_proto->return_type);
-    int types_len = fn_proto->params.length + 1;
-    for (int i = 0; i < fn_proto->params.length; i += 1) {
-        AstNode *param_node = fn_proto->params.at(i);
-        assert(param_node->type == NodeTypeParamDecl);
-        LLVMZigDIType *param_type = to_llvm_debug_type(g, param_node->data.param_decl.type);
-        types[i + 1] = param_type;
-    }
-    return LLVMZigCreateSubroutineType(g->dbuilder, di_file, types, types_len, 0);
-}
-
-static LLVMAttribute to_llvm_fn_attr(FnAttrId attr_id) {
-    switch (attr_id) {
-        case FnAttrIdNaked:
-            return LLVMNakedAttribute;
-        case FnAttrIdAlwaysInline:
-            return LLVMAlwaysInlineAttribute;
-    }
-    zig_unreachable();
-}
-
 static void do_code_gen(CodeGen *g) {
     assert(!g->errors.length);
 
@@ -2172,45 +2134,12 @@ static void do_code_gen(CodeGen *g) {
     // Generate function prototypes
     for (int fn_proto_i = 0; fn_proto_i < g->fn_protos.length; fn_proto_i += 1) {
         FnTableEntry *fn_table_entry = g->fn_protos.at(fn_proto_i);
-
         AstNode *proto_node = fn_table_entry->proto_node;
         assert(proto_node->type == NodeTypeFnProto);
         AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
-        LLVMTypeRef ret_type = get_type_for_type_node(fn_proto->return_type)->type_ref;
-        int param_count = count_non_void_params(g, &fn_proto->params);
-        LLVMTypeRef *param_types = allocate<LLVMTypeRef>(param_count);
-        int gen_param_index = 0;
-        for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) {
-            AstNode *param_node = fn_proto->params.at(param_decl_i);
-            assert(param_node->type == NodeTypeParamDecl);
-            if (is_param_decl_type_void(g, param_node))
-                continue;
-            AstNode *type_node = param_node->data.param_decl.type;
-            param_types[gen_param_index] = fn_proto_type_from_type_node(g, type_node)->type_ref;
-            gen_param_index += 1;
-        }
-        LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, param_count, fn_proto->is_var_args);
-
-        LLVMValueRef fn = LLVMAddFunction(g->module, buf_ptr(&fn_table_entry->symbol_name), function_type);
-
-        for (int attr_i = 0; attr_i < fn_table_entry->fn_attr_list.length; attr_i += 1) {
-            FnAttrId attr_id = fn_table_entry->fn_attr_list.at(attr_i);
-            LLVMAddFunctionAttr(fn, to_llvm_fn_attr(attr_id));
-        }
-
-        LLVMSetLinkage(fn, fn_table_entry->internal_linkage ? LLVMInternalLinkage : LLVMExternalLinkage);
-
-        if (type_is_unreachable(g, fn_proto->return_type)) {
-            LLVMAddFunctionAttr(fn, LLVMNoReturnAttribute);
-        }
-        LLVMSetFunctionCallConv(fn, fn_table_entry->calling_convention);
-        if (!fn_table_entry->is_extern) {
-            LLVMAddFunctionAttr(fn, LLVMNoUnwindAttribute);
-        }
-
         // set parameter attributes
-        gen_param_index = 0;
+        int gen_param_index = 0;
         for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) {
             AstNode *param_node = fn_proto->params.at(param_decl_i);
             assert(param_node->type == NodeTypeParamDecl);
@@ -2218,7 +2147,7 @@ static void do_code_gen(CodeGen *g) {
                 continue;
             AstNode *type_node = param_node->data.param_decl.type;
             TypeTableEntry *param_type = fn_proto_type_from_type_node(g, type_node);
-            LLVMValueRef argument_val = LLVMGetParam(fn, gen_param_index);
+            LLVMValueRef argument_val = LLVMGetParam(fn_table_entry->fn_value, gen_param_index);
             bool param_is_noalias = param_node->data.param_decl.is_noalias;
             if (param_type->id == TypeTableEntryIdPointer && param_is_noalias) {
                 LLVMAddAttribute(argument_val, LLVMNoAliasAttribute);
@@ -2230,7 +2159,6 @@ static void do_code_gen(CodeGen *g) {
             gen_param_index += 1;
         }
 
-        fn_table_entry->fn_value = fn;
     }
 
     // Generate function definitions.
@@ -2245,22 +2173,9 @@ static void do_code_gen(CodeGen *g) {
         assert(proto_node->type == NodeTypeFnProto);
         AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
-        // Add debug info.
-        unsigned line_number = fn_def_node->line + 1;
-        unsigned scope_line = line_number;
-        bool is_definition = true;
-        unsigned flags = 0;
-        bool is_optimized = g->build_type == CodeGenBuildTypeRelease;
-        LLVMZigDISubprogram *subprogram = LLVMZigCreateFunction(g->dbuilder,
-            import->block_context->di_scope, buf_ptr(&fn_table_entry->symbol_name), "",
-            import->di_file, line_number,
-            create_di_function_type(g, fn_proto, import->di_file), fn_table_entry->internal_linkage, 
-            is_definition, scope_line, flags, is_optimized, fn);
-
         LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
         LLVMPositionBuilderAtEnd(g->builder, entry_block);
 
-        fn_def_node->data.fn_def.block_context->di_scope = LLVMZigSubprogramToScope(subprogram);
 
         AstNode *body_node = fn_def_node->data.fn_def.body;
         build_label_blocks(g, body_node);
src/zig_llvm.cpp
@@ -388,6 +388,11 @@ LLVMZigDIScope *LLVMZigSubprogramToScope(LLVMZigDISubprogram *subprogram) {
     return reinterpret_cast<LLVMZigDIScope*>(scope);
 }
 
+LLVMZigDIType *LLVMZigSubroutineToType(LLVMZigDISubroutineType *subrtype) {
+    DIType *di_type = reinterpret_cast<DISubroutineType*>(subrtype);
+    return reinterpret_cast<LLVMZigDIType*>(di_type);
+}
+
 LLVMZigDIScope *LLVMZigTypeToScope(LLVMZigDIType *type) {
     DIScope *scope = reinterpret_cast<DIType*>(type);
     return reinterpret_cast<LLVMZigDIScope*>(scope);
src/zig_llvm.hpp
@@ -101,6 +101,7 @@ LLVMZigDIScope *LLVMZigCompileUnitToScope(LLVMZigDICompileUnit *compile_unit);
 LLVMZigDIScope *LLVMZigFileToScope(LLVMZigDIFile *difile);
 LLVMZigDIScope *LLVMZigSubprogramToScope(LLVMZigDISubprogram *subprogram);
 LLVMZigDIScope *LLVMZigTypeToScope(LLVMZigDIType *type);
+LLVMZigDIType *LLVMZigSubroutineToType(LLVMZigDISubroutineType *subrtype);
 
 LLVMZigDILocalVariable *LLVMZigCreateLocalVariable(LLVMZigDIBuilder *dbuilder, unsigned tag,
         LLVMZigDIScope *scope, const char *name, LLVMZigDIFile *file, unsigned line_no,
test/run_tests.cpp
@@ -1296,10 +1296,6 @@ fn f() => {
 fn f(a : unreachable) => {}
     )SOURCE", 1, ".tmp_source.zig:2:10: error: parameter of type 'unreachable' not allowed");
 
-    add_compile_fail_case("exporting a void parameter", R"SOURCE(
-export fn f(a : void) => {}
-    )SOURCE", 1, ".tmp_source.zig:2:17: error: parameter of type 'void' not allowed on exported functions");
-
     add_compile_fail_case("unused label", R"SOURCE(
 fn f() => {
 a_label: