Commit 2521afef69

Andrew Kelley <superjoe30@gmail.com>
2016-02-04 05:34:09
add ability to call function pointer field
also introduce the self hosted tests closes #108
1 parent 5c310f4
src/all_types.hpp
@@ -357,6 +357,7 @@ struct AstNodeFnCallExpr {
     Expr resolved_expr;
     FnTableEntry *fn_entry;
     CastOp cast_op;
+    TypeTableEntry *enum_type;
     // if cast_op is CastOpArrayToString, this will be a pointer to
     // the string struct on the stack
     LLVMValueRef tmp_ptr;
@@ -390,6 +391,9 @@ struct AstNodeFieldAccessExpr {
     TypeEnumField *type_enum_field;
     Expr resolved_expr;
     StructValExprCodeGen resolved_struct_val_expr; // for enum values
+    bool is_fn_call;
+    TypeTableEntry *bare_struct_type;
+    bool is_member_fn;
 };
 
 struct AstNodeDirective {
src/analyze.cpp
@@ -28,6 +28,7 @@ static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import,
         TypeTableEntry *expected_type, AstNode *node);
 static TypeTableEntry *resolve_expr_const_val_as_void(CodeGen *g, AstNode *node);
 static TypeTableEntry *resolve_expr_const_val_as_fn(CodeGen *g, AstNode *node, FnTableEntry *fn);
+static TypeTableEntry *resolve_expr_const_val_as_type(CodeGen *g, AstNode *node, TypeTableEntry *type);
 static void detect_top_level_decl_deps(CodeGen *g, ImportTableEntry *import, AstNode *node);
 static void analyze_top_level_decls_root(CodeGen *g, ImportTableEntry *import, AstNode *node);
 
@@ -2220,15 +2221,28 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i
     TypeTableEntry *struct_type = analyze_expression(g, import, context, nullptr, struct_expr_node);
     Buf *field_name = &node->data.field_access_expr.field_name;
 
+    bool wrapped_in_fn_call = node->data.field_access_expr.is_fn_call;
+
     if (struct_type->id == TypeTableEntryIdStruct || (struct_type->id == TypeTableEntryIdPointer &&
          struct_type->data.pointer.child_type->id == TypeTableEntryIdStruct))
     {
         TypeTableEntry *bare_struct_type = (struct_type->id == TypeTableEntryIdStruct) ?
             struct_type : struct_type->data.pointer.child_type;
 
+        node->data.field_access_expr.bare_struct_type = bare_struct_type;
         node->data.field_access_expr.type_struct_field = find_struct_type_field(bare_struct_type, field_name);
         if (node->data.field_access_expr.type_struct_field) {
             return node->data.field_access_expr.type_struct_field->type_entry;
+        } else if (wrapped_in_fn_call) {
+            auto table_entry = bare_struct_type->data.structure.fn_table.maybe_get(field_name);
+            if (table_entry) {
+                node->data.field_access_expr.is_member_fn = true;
+                return resolve_expr_const_val_as_fn(g, node, table_entry->value);
+            } else {
+                add_node_error(g, node, buf_sprintf("no member named '%s' in '%s'",
+                    buf_ptr(field_name), buf_ptr(&bare_struct_type->name)));
+                return g->builtin_types.entry_invalid;
+            }
         } else {
             add_node_error(g, node,
                 buf_sprintf("no member named '%s' in '%s'", buf_ptr(field_name), buf_ptr(&struct_type->name)));
@@ -2251,6 +2265,8 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i
 
         if (child_type->id == TypeTableEntryIdInvalid) {
             return g->builtin_types.entry_invalid;
+        } else if (wrapped_in_fn_call) {
+            return resolve_expr_const_val_as_type(g, node, child_type);
         } else if (child_type->id == TypeTableEntryIdEnum) {
             return analyze_enum_value_expr(g, import, context, node, nullptr, child_type, field_name);
         } else if (child_type->id == TypeTableEntryIdStruct) {
@@ -4128,72 +4144,7 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
     }
 
     if (fn_ref_expr->type == NodeTypeFieldAccessExpr) {
-        fn_ref_expr->block_context = context;
-        AstNode *first_param_expr = fn_ref_expr->data.field_access_expr.struct_expr;
-        TypeTableEntry *struct_type = analyze_expression(g, import, context, nullptr, first_param_expr);
-        Buf *name = &fn_ref_expr->data.field_access_expr.field_name;
-        if (struct_type->id == TypeTableEntryIdStruct ||
-            (struct_type->id == TypeTableEntryIdPointer &&
-            struct_type->data.pointer.child_type->id == TypeTableEntryIdStruct))
-        {
-            TypeTableEntry *bare_struct_type = (struct_type->id == TypeTableEntryIdStruct) ?
-                struct_type : struct_type->data.pointer.child_type;
-
-            auto table_entry = bare_struct_type->data.structure.fn_table.maybe_get(name);
-            if (table_entry) {
-                return analyze_fn_call_raw(g, import, context, expected_type, node,
-                        table_entry->value, bare_struct_type);
-            } else {
-                add_node_error(g, fn_ref_expr,
-                        buf_sprintf("no function named '%s' in '%s'",
-                            buf_ptr(name), buf_ptr(&bare_struct_type->name)));
-                return g->builtin_types.entry_invalid;
-            }
-        } else if (struct_type->id == TypeTableEntryIdInvalid) {
-            return struct_type;
-        } else if (struct_type->id == TypeTableEntryIdMetaType) {
-            TypeTableEntry *child_type = resolve_type(g, first_param_expr);
-
-            if (child_type->id == TypeTableEntryIdInvalid) {
-                return g->builtin_types.entry_invalid;
-            } else if (child_type->id == TypeTableEntryIdEnum) {
-                Buf *field_name = &fn_ref_expr->data.field_access_expr.field_name;
-                int param_count = node->data.fn_call_expr.params.length;
-                if (param_count > 1) {
-                    add_node_error(g, first_executing_node(node->data.fn_call_expr.params.at(1)),
-                            buf_sprintf("enum values accept only one parameter"));
-                    return child_type;
-                } else {
-                    AstNode *value_node;
-                    if (param_count == 1) {
-                        value_node = node->data.fn_call_expr.params.at(0);
-                    } else {
-                        value_node = nullptr;
-                    }
-
-                    return analyze_enum_value_expr(g, import, context, fn_ref_expr, value_node,
-                            child_type, field_name);
-                }
-            } else if (child_type->id == TypeTableEntryIdStruct) {
-                Buf *field_name = &fn_ref_expr->data.field_access_expr.field_name;
-                auto entry = child_type->data.structure.fn_table.maybe_get(field_name);
-                if (entry) {
-                    return analyze_fn_call_raw(g, import, context, expected_type, node,
-                            entry->value, nullptr);
-                } else {
-                    add_node_error(g, node,
-                        buf_sprintf("struct '%s' has no function called '%s'",
-                            buf_ptr(&child_type->name), buf_ptr(field_name)));
-                    return g->builtin_types.entry_invalid;
-                }
-            } else {
-                add_node_error(g, first_param_expr, buf_sprintf("member reference base type not struct or enum"));
-                return g->builtin_types.entry_invalid;
-            }
-        } else {
-            add_node_error(g, first_param_expr, buf_sprintf("member reference base type not struct or enum"));
-            return g->builtin_types.entry_invalid;
-        }
+        fn_ref_expr->data.field_access_expr.is_fn_call = true;
     }
 
     TypeTableEntry *invoke_type_entry = analyze_expression(g, import, context, nullptr, fn_ref_expr);
@@ -4207,9 +4158,62 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
 
     if (const_val->ok) {
         if (invoke_type_entry->id == TypeTableEntryIdMetaType) {
-            return analyze_cast_expr(g, import, context, node);
+            if (fn_ref_expr->type == NodeTypeFieldAccessExpr) {
+                TypeTableEntry *child_type = resolve_type(g, fn_ref_expr);
+
+                if (child_type->id == TypeTableEntryIdInvalid) {
+                    return g->builtin_types.entry_invalid;
+                } else if (child_type->id == TypeTableEntryIdEnum) {
+                    Buf *field_name = &fn_ref_expr->data.field_access_expr.field_name;
+                    int param_count = node->data.fn_call_expr.params.length;
+                    if (param_count > 1) {
+                        add_node_error(g, first_executing_node(node->data.fn_call_expr.params.at(1)),
+                                buf_sprintf("enum values accept only one parameter"));
+                        return child_type;
+                    } else {
+                        AstNode *value_node;
+                        if (param_count == 1) {
+                            value_node = node->data.fn_call_expr.params.at(0);
+                        } else {
+                            value_node = nullptr;
+                        }
+
+                        node->data.fn_call_expr.enum_type = child_type;
+
+                        return analyze_enum_value_expr(g, import, context, fn_ref_expr, value_node,
+                                child_type, field_name);
+                    }
+                } else if (child_type->id == TypeTableEntryIdStruct) {
+                    Buf *field_name = &fn_ref_expr->data.field_access_expr.field_name;
+                    auto entry = child_type->data.structure.fn_table.maybe_get(field_name);
+                    if (entry) {
+                        return analyze_fn_call_raw(g, import, context, expected_type, node,
+                                entry->value, nullptr);
+                    } else {
+                        add_node_error(g, node,
+                            buf_sprintf("struct '%s' has no function called '%s'",
+                                buf_ptr(&child_type->name), buf_ptr(field_name)));
+                        return g->builtin_types.entry_invalid;
+                    }
+                } else {
+                    add_node_error(g, fn_ref_expr, buf_sprintf("member reference base type not struct or enum"));
+                    return g->builtin_types.entry_invalid;
+                }
+            } else {
+                return analyze_cast_expr(g, import, context, node);
+            }
         } else if (invoke_type_entry->id == TypeTableEntryIdFn) {
-            return analyze_fn_call_raw(g, import, context, expected_type, node, const_val->data.x_fn, nullptr);
+            TypeTableEntry *bare_struct_type;
+            if (fn_ref_expr->type == NodeTypeFieldAccessExpr &&
+                fn_ref_expr->data.field_access_expr.is_member_fn)
+            {
+                bare_struct_type = fn_ref_expr->data.field_access_expr.bare_struct_type;
+            } else {
+                bare_struct_type = nullptr;
+            }
+
+            return analyze_fn_call_raw(g, import, context, expected_type, node,
+                    const_val->data.x_fn, bare_struct_type);
         } else {
             add_node_error(g, fn_ref_expr,
                 buf_sprintf("type '%s' not a function", buf_ptr(&invoke_type_entry->name)));
src/codegen.cpp
@@ -544,41 +544,28 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
         return gen_cast_expr(g, node);
     }
 
-    FnTableEntry *fn_table_entry = node->data.fn_call_expr.fn_entry;
     AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+    if (node->data.fn_call_expr.enum_type) {
+        int param_count = node->data.fn_call_expr.params.length;
+        AstNode *arg1_node;
+        if (param_count == 1) {
+            arg1_node = node->data.fn_call_expr.params.at(0);
+        } else {
+            assert(param_count == 0);
+            arg1_node = nullptr;
+        }
+        return gen_enum_value_expr(g, fn_ref_expr, node->data.fn_call_expr.enum_type, arg1_node);
+    }
+
+    FnTableEntry *fn_table_entry = node->data.fn_call_expr.fn_entry;
     TypeTableEntry *struct_type = nullptr;
     AstNode *first_param_expr = nullptr;
-    if (fn_ref_expr->type == NodeTypeFieldAccessExpr) {
+
+    if (fn_ref_expr->type == NodeTypeFieldAccessExpr &&
+        fn_ref_expr->data.field_access_expr.is_member_fn)
+    {
         first_param_expr = fn_ref_expr->data.field_access_expr.struct_expr;
         struct_type = get_expr_type(first_param_expr);
-        if (struct_type->id == TypeTableEntryIdStruct) {
-            fn_table_entry = node->data.fn_call_expr.fn_entry;
-        } else if (struct_type->id == TypeTableEntryIdPointer) {
-            assert(struct_type->data.pointer.child_type->id == TypeTableEntryIdStruct);
-            fn_table_entry = node->data.fn_call_expr.fn_entry;
-        } else if (struct_type->id == TypeTableEntryIdMetaType) {
-            TypeTableEntry *child_type = get_type_for_type_node(first_param_expr);
-
-            if (child_type->id == TypeTableEntryIdEnum) {
-                int param_count = node->data.fn_call_expr.params.length;
-                AstNode *arg1_node;
-                if (param_count == 1) {
-                    arg1_node = node->data.fn_call_expr.params.at(0);
-                } else {
-                    assert(param_count == 0);
-                    arg1_node = nullptr;
-                }
-                return gen_enum_value_expr(g, fn_ref_expr, child_type, arg1_node);
-            } else if (child_type->id == TypeTableEntryIdStruct) {
-                struct_type = nullptr;
-                first_param_expr = nullptr;
-                fn_table_entry = node->data.fn_call_expr.fn_entry;
-            } else {
-                zig_unreachable();
-            }
-        } else {
-            zig_unreachable();
-        }
     }
 
     TypeTableEntry *fn_type;
std/test_runner.zig
@@ -17,9 +17,7 @@ pub fn main(args: [][]u8) -> %void {
         %%stderr.print_str(test_fn.name);
         %%stderr.print_str("...");
 
-        // TODO support calling function pointers as fields directly
-        const fn_ptr = test_fn.func;
-        fn_ptr();
+        test_fn.func();
 
 
         %%stderr.print_str("OK\n");
test/run_tests.cpp
@@ -25,6 +25,7 @@ struct TestCase {
     ZigList<const char *> compiler_args;
     ZigList<const char *> program_args;
     bool is_parseh;
+    bool is_self_hosted;
 };
 
 static ZigList<TestCase*> test_cases = {0};
@@ -157,21 +158,6 @@ fn this_is_a_function() -> unreachable {
 }
     )SOURCE", "OK\n");
 
-    add_simple_case("comments", R"SOURCE(
-import "std.zig";
-
-/**
-    * multi line doc comment
-    */
-fn another_function() {}
-
-/// this is a documentation comment
-/// doc comment line 2
-pub fn main(args: [][]u8) -> %void {
-    %%stdout.printf(/* mid-line comment /* nested */ */ "OK\n");
-}
-    )SOURCE", "OK\n");
-
     {
         TestCase *tc = add_simple_case("multiple files with private function", R"SOURCE(
 import "std.zig";
@@ -2205,6 +2191,30 @@ extern void (*fn_ptr)(void);
 })SOURCE");
 }
 
+static void run_self_hosted_test(void) {
+    Buf zig_stderr = BUF_INIT;
+    Buf zig_stdout = BUF_INIT;
+    int return_code;
+    ZigList<const char *> args = {0};
+    args.append("test");
+    args.append("../test/self_hosted.zig");
+    os_exec_process(zig_exe, args, &return_code, &zig_stderr, &zig_stdout);
+
+    if (return_code) {
+        printf("\nSelf-hosted tests failed:\n");
+        printf("./zig test ../test/self_hosted.zig\n");
+        printf("%s\n", buf_ptr(&zig_stderr));
+        exit(1);
+    }
+}
+
+static void add_self_hosted_tests(void) {
+    TestCase *test_case = allocate<TestCase>(1);
+    test_case->case_name = "self hosted tests";
+    test_case->is_self_hosted = true;
+    test_cases.append(test_case);
+}
+
 static void print_compiler_invocation(TestCase *test_case) {
     printf("%s", zig_exe);
     for (int i = 0; i < test_case->compiler_args.length; i += 1) {
@@ -2214,6 +2224,10 @@ static void print_compiler_invocation(TestCase *test_case) {
 }
 
 static void run_test(TestCase *test_case) {
+    if (test_case->is_self_hosted) {
+        return run_self_hosted_test();
+    }
+
     for (int i = 0; i < test_case->source_files.length; i += 1) {
         TestSourceFile *test_source = &test_case->source_files.at(i);
         os_write_file(
@@ -2359,6 +2373,7 @@ int main(int argc, char **argv) {
     add_compiling_test_cases();
     add_compile_failure_test_cases();
     add_parseh_test_cases();
+    add_self_hosted_tests();
     run_all_tests(reverse);
     cleanup();
 }
test/self_hosted.zig
@@ -0,0 +1,37 @@
+#attribute("test")
+fn empty_function() {}
+
+
+
+
+/**
+    * multi line doc comment
+    */
+/// this is a documentation comment
+/// doc comment line 2
+#attribute("test")
+fn comments() {
+    comments_f1(/* mid-line comment /* nested */ */ "OK\n");
+}
+
+fn comments_f1(s: []u8) {}
+
+
+
+
+#attribute("test")
+fn fn_call_of_struct_field() {
+    if (call_struct_field(Foo {.ptr = a_func,}) != 13) {
+        unreachable{};
+    }
+}
+
+struct Foo {
+    ptr: fn() -> i32,
+}
+
+fn a_func() -> i32 { 13 }
+
+fn call_struct_field(foo: Foo) -> i32 {
+    return foo.ptr();
+}