Commit c42c91ee7c

Andrew Kelley <superjoe30@gmail.com>
2017-05-26 20:39:18
fix segfault with array of generic functions
closes #377
1 parent fcdd808
src/analyze.cpp
@@ -1033,10 +1033,10 @@ static TypeTableEntry *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *c
         AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index);
         assert(param_node->type == NodeTypeParamDecl);
 
-        bool param_is_inline = param_node->data.param_decl.is_inline;
+        bool param_is_comptime = param_node->data.param_decl.is_inline;
         bool param_is_var_args = param_node->data.param_decl.is_var_args;
 
-        if (param_is_inline) {
+        if (param_is_comptime) {
             if (fn_type_id.is_extern) {
                 add_node_error(g, param_node,
                         buf_sprintf("comptime parameter not allowed in extern function"));
@@ -2507,7 +2507,10 @@ bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTableEntry *
         if (expected_type->data.fn.fn_type_id.is_var_args != actual_type->data.fn.fn_type_id.is_var_args) {
             return false;
         }
-        if (!expected_type->data.fn.fn_type_id.is_var_args && 
+        if (expected_type->data.fn.is_generic != actual_type->data.fn.is_generic) {
+            return false;
+        }
+        if (!expected_type->data.fn.is_generic && 
             actual_type->data.fn.fn_type_id.return_type->id != TypeTableEntryIdUnreachable &&
             !types_match_const_cast_only(
                 expected_type->data.fn.fn_type_id.return_type,
@@ -2518,12 +2521,12 @@ bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTableEntry *
         if (expected_type->data.fn.fn_type_id.param_count != actual_type->data.fn.fn_type_id.param_count) {
             return false;
         }
-        for (size_t i = 0; i < expected_type->data.fn.fn_type_id.param_count; i += 1) {
-            if (i == expected_type->data.fn.fn_type_id.param_count - 1 &&
-                expected_type->data.fn.fn_type_id.is_var_args)
-            {
-                continue;
-            }
+        if (expected_type->data.fn.fn_type_id.next_param_index != actual_type->data.fn.fn_type_id.next_param_index) {
+            return false;
+        }
+        assert(expected_type->data.fn.is_generic ||
+                expected_type->data.fn.fn_type_id.next_param_index  == expected_type->data.fn.fn_type_id.param_count);
+        for (size_t i = 0; i < expected_type->data.fn.fn_type_id.next_param_index; i += 1) {
             // note it's reversed for parameters
             FnTypeParamInfo *actual_param_info = &actual_type->data.fn.fn_type_id.param_info[i];
             FnTypeParamInfo *expected_param_info = &expected_type->data.fn.fn_type_id.param_info[i];
src/ir.cpp
@@ -13089,12 +13089,20 @@ static TypeTableEntry *ir_analyze_instruction_fn_proto(IrAnalyze *ira, IrInstruc
             }
         }
         IrInstruction *param_type_value = instruction->param_types[fn_type_id.next_param_index]->other;
+        if (type_is_invalid(param_type_value->value.type)) 
+            return ira->codegen->builtin_types.entry_invalid;
 
         FnTypeParamInfo *param_info = &fn_type_id.param_info[fn_type_id.next_param_index];
         param_info->is_noalias = param_node->data.param_decl.is_noalias;
         param_info->type = ir_resolve_type(ira, param_type_value);
         if (type_is_invalid(param_info->type))
             return ira->codegen->builtin_types.entry_invalid;
+
+        if (param_info->type->id == TypeTableEntryIdVar) {
+            ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+            out_val->data.x_type = get_generic_fn_type(ira->codegen, &fn_type_id);
+            return ira->codegen->builtin_types.entry_type;
+        }
     }
 
     IrInstruction *return_type_value = instruction->return_type->other;
test/cases/generics.zig
@@ -1,6 +1,6 @@
 const assert = @import("std").debug.assert;
 
-test "simpleGenericFn" {
+test "simple generic fn" {
     assert(max(i32, 3, -1) == 3);
     assert(max(f32, 0.123, 0.456) == 0.456);
     assert(add(2, 3) == 5);
@@ -15,7 +15,7 @@ fn add(comptime a: i32, b: i32) -> i32 {
 }
 
 const the_max = max(u32, 1234, 5678);
-test "compileTimeGenericEval" {
+test "compile time generic eval" {
     assert(the_max == 5678);
 }
 
@@ -31,21 +31,22 @@ fn sameButWithFloats(a: f64, b: f64) -> f64 {
     max(f64, a, b)
 }
 
-test "fnWithInlineArgs" {
+test "fn with comptime args" {
     assert(gimmeTheBigOne(1234, 5678) == 5678);
     assert(shouldCallSameInstance(34, 12) == 34);
     assert(sameButWithFloats(0.43, 0.49) == 0.49);
 }
 
 
-test "varParams" {
+test "var params" {
     assert(max_i32(12, 34) == 34);
     assert(max_f64(1.2, 3.4) == 3.4);
 }
 
-// TODO `_`
-const _1 = assert(max_i32(12, 34) == 34);
-const _2 = assert(max_f64(1.2, 3.4) == 3.4);
+comptime {
+    assert(max_i32(12, 34) == 34);
+    assert(max_f64(1.2, 3.4) == 3.4);
+}
 
 fn max_var(a: var, b: var) -> @typeOf(a + b) {
     if (a > b) a else b
@@ -72,7 +73,7 @@ pub fn SmallList(comptime T: type, comptime STATIC_SIZE: usize) -> type {
     }
 }
 
-test "functionWithReturnTypeType" {
+test "function with return type type" {
     var list: List(i32) = undefined;
     var list2: List(i32) = undefined;
     list.length = 10;
@@ -82,7 +83,7 @@ test "functionWithReturnTypeType" {
 }
 
 
-test "genericStruct" {
+test "generic struct" {
     var a1 = GenNode(i32) {.value = 13, .next = null,};
     var b1 = GenNode(bool) {.value = true, .next = null,};
     assert(a1.value == 13);
@@ -97,7 +98,7 @@ fn GenNode(comptime T: type) -> type {
     }
 }
 
-test "constDeclsInStruct" {
+test "const decls in struct" {
     assert(GenericDataThing(3).count_plus_one == 4);
 }
 fn GenericDataThing(comptime count: isize) -> type {
@@ -107,7 +108,7 @@ fn GenericDataThing(comptime count: isize) -> type {
 }
 
 
-test "useGenericParamInGenericParam" {
+test "use generic param in generic param" {
     assert(aGenericFn(i32, 3, 4) == 7);
 }
 fn aGenericFn(comptime T: type, comptime a: T, b: T) -> T {
@@ -115,7 +116,7 @@ fn aGenericFn(comptime T: type, comptime a: T, b: T) -> T {
 }
 
 
-test "genericFnWithImplicitCast" {
+test "generic fn with implicit cast" {
     assert(getFirstByte(u8, []u8 {13}) == 13);
     assert(getFirstByte(u16, []u16 {0, 13}) == 0);
 }
@@ -123,3 +124,14 @@ fn getByte(ptr: ?&const u8) -> u8 {*??ptr}
 fn getFirstByte(comptime T: type, mem: []const T) -> u8 {
     getByte(@ptrCast(&const u8, &mem[0]))
 }
+
+
+const foos = []fn(var) -> bool { foo1, foo2 };
+
+fn foo1(arg: var) -> bool { arg }
+fn foo2(arg: var) -> bool { !arg }
+
+test "array of generic fns" {
+    assert(foos[0](true));
+    assert(!foos[1](true));
+}
test/cases/var_args.zig
@@ -54,3 +54,14 @@ fn extraFn(extra: u32, args: ...) -> usize {
     }
     return args.len;
 }
+
+
+const foos = []fn(...) -> bool { foo1, foo2 };
+
+fn foo1(args: ...) -> bool { true }
+fn foo2(args: ...) -> bool { false }
+
+test "array of var args functions" {
+    assert(foos[0]());
+    assert(!foos[1]());
+}
test/compile_errors.zig
@@ -1904,4 +1904,16 @@ pub fn addCases(cases: &tests.CompileErrorContext) {
         \\}
     ,
         ".tmp_source.zig:7:9: error: calling a generic function requires compile-time known function value");
+
+    cases.add("calling a generic function only known at runtime",
+        \\var foos = []fn(var) { foo1, foo2 };
+        \\
+        \\fn foo1(arg: var) {}
+        \\fn foo2(arg: var) {}
+        \\
+        \\pub fn main() -> %void {
+        \\    foos[0](true);
+        \\}
+    ,
+        ".tmp_source.zig:7:9: error: calling a generic function requires compile-time known function value");
 }