Commit c42c91ee7c
Changed files (5)
src/analyze.cpp
@@ -1033,10 +1033,10 @@ static TypeTableEntry *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *c
AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index);
assert(param_node->type == NodeTypeParamDecl);
- bool param_is_inline = param_node->data.param_decl.is_inline;
+ bool param_is_comptime = param_node->data.param_decl.is_inline;
bool param_is_var_args = param_node->data.param_decl.is_var_args;
- if (param_is_inline) {
+ if (param_is_comptime) {
if (fn_type_id.is_extern) {
add_node_error(g, param_node,
buf_sprintf("comptime parameter not allowed in extern function"));
@@ -2507,7 +2507,10 @@ bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTableEntry *
if (expected_type->data.fn.fn_type_id.is_var_args != actual_type->data.fn.fn_type_id.is_var_args) {
return false;
}
- if (!expected_type->data.fn.fn_type_id.is_var_args &&
+ if (expected_type->data.fn.is_generic != actual_type->data.fn.is_generic) {
+ return false;
+ }
+ if (!expected_type->data.fn.is_generic &&
actual_type->data.fn.fn_type_id.return_type->id != TypeTableEntryIdUnreachable &&
!types_match_const_cast_only(
expected_type->data.fn.fn_type_id.return_type,
@@ -2518,12 +2521,12 @@ bool types_match_const_cast_only(TypeTableEntry *expected_type, TypeTableEntry *
if (expected_type->data.fn.fn_type_id.param_count != actual_type->data.fn.fn_type_id.param_count) {
return false;
}
- for (size_t i = 0; i < expected_type->data.fn.fn_type_id.param_count; i += 1) {
- if (i == expected_type->data.fn.fn_type_id.param_count - 1 &&
- expected_type->data.fn.fn_type_id.is_var_args)
- {
- continue;
- }
+ if (expected_type->data.fn.fn_type_id.next_param_index != actual_type->data.fn.fn_type_id.next_param_index) {
+ return false;
+ }
+ assert(expected_type->data.fn.is_generic ||
+ expected_type->data.fn.fn_type_id.next_param_index == expected_type->data.fn.fn_type_id.param_count);
+ for (size_t i = 0; i < expected_type->data.fn.fn_type_id.next_param_index; i += 1) {
// note it's reversed for parameters
FnTypeParamInfo *actual_param_info = &actual_type->data.fn.fn_type_id.param_info[i];
FnTypeParamInfo *expected_param_info = &expected_type->data.fn.fn_type_id.param_info[i];
src/ir.cpp
@@ -13089,12 +13089,20 @@ static TypeTableEntry *ir_analyze_instruction_fn_proto(IrAnalyze *ira, IrInstruc
}
}
IrInstruction *param_type_value = instruction->param_types[fn_type_id.next_param_index]->other;
+ if (type_is_invalid(param_type_value->value.type))
+ return ira->codegen->builtin_types.entry_invalid;
FnTypeParamInfo *param_info = &fn_type_id.param_info[fn_type_id.next_param_index];
param_info->is_noalias = param_node->data.param_decl.is_noalias;
param_info->type = ir_resolve_type(ira, param_type_value);
if (type_is_invalid(param_info->type))
return ira->codegen->builtin_types.entry_invalid;
+
+ if (param_info->type->id == TypeTableEntryIdVar) {
+ ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+ out_val->data.x_type = get_generic_fn_type(ira->codegen, &fn_type_id);
+ return ira->codegen->builtin_types.entry_type;
+ }
}
IrInstruction *return_type_value = instruction->return_type->other;
test/cases/generics.zig
@@ -1,6 +1,6 @@
const assert = @import("std").debug.assert;
-test "simpleGenericFn" {
+test "simple generic fn" {
assert(max(i32, 3, -1) == 3);
assert(max(f32, 0.123, 0.456) == 0.456);
assert(add(2, 3) == 5);
@@ -15,7 +15,7 @@ fn add(comptime a: i32, b: i32) -> i32 {
}
const the_max = max(u32, 1234, 5678);
-test "compileTimeGenericEval" {
+test "compile time generic eval" {
assert(the_max == 5678);
}
@@ -31,21 +31,22 @@ fn sameButWithFloats(a: f64, b: f64) -> f64 {
max(f64, a, b)
}
-test "fnWithInlineArgs" {
+test "fn with comptime args" {
assert(gimmeTheBigOne(1234, 5678) == 5678);
assert(shouldCallSameInstance(34, 12) == 34);
assert(sameButWithFloats(0.43, 0.49) == 0.49);
}
-test "varParams" {
+test "var params" {
assert(max_i32(12, 34) == 34);
assert(max_f64(1.2, 3.4) == 3.4);
}
-// TODO `_`
-const _1 = assert(max_i32(12, 34) == 34);
-const _2 = assert(max_f64(1.2, 3.4) == 3.4);
+comptime {
+ assert(max_i32(12, 34) == 34);
+ assert(max_f64(1.2, 3.4) == 3.4);
+}
fn max_var(a: var, b: var) -> @typeOf(a + b) {
if (a > b) a else b
@@ -72,7 +73,7 @@ pub fn SmallList(comptime T: type, comptime STATIC_SIZE: usize) -> type {
}
}
-test "functionWithReturnTypeType" {
+test "function with return type type" {
var list: List(i32) = undefined;
var list2: List(i32) = undefined;
list.length = 10;
@@ -82,7 +83,7 @@ test "functionWithReturnTypeType" {
}
-test "genericStruct" {
+test "generic struct" {
var a1 = GenNode(i32) {.value = 13, .next = null,};
var b1 = GenNode(bool) {.value = true, .next = null,};
assert(a1.value == 13);
@@ -97,7 +98,7 @@ fn GenNode(comptime T: type) -> type {
}
}
-test "constDeclsInStruct" {
+test "const decls in struct" {
assert(GenericDataThing(3).count_plus_one == 4);
}
fn GenericDataThing(comptime count: isize) -> type {
@@ -107,7 +108,7 @@ fn GenericDataThing(comptime count: isize) -> type {
}
-test "useGenericParamInGenericParam" {
+test "use generic param in generic param" {
assert(aGenericFn(i32, 3, 4) == 7);
}
fn aGenericFn(comptime T: type, comptime a: T, b: T) -> T {
@@ -115,7 +116,7 @@ fn aGenericFn(comptime T: type, comptime a: T, b: T) -> T {
}
-test "genericFnWithImplicitCast" {
+test "generic fn with implicit cast" {
assert(getFirstByte(u8, []u8 {13}) == 13);
assert(getFirstByte(u16, []u16 {0, 13}) == 0);
}
@@ -123,3 +124,14 @@ fn getByte(ptr: ?&const u8) -> u8 {*??ptr}
fn getFirstByte(comptime T: type, mem: []const T) -> u8 {
getByte(@ptrCast(&const u8, &mem[0]))
}
+
+
+const foos = []fn(var) -> bool { foo1, foo2 };
+
+fn foo1(arg: var) -> bool { arg }
+fn foo2(arg: var) -> bool { !arg }
+
+test "array of generic fns" {
+ assert(foos[0](true));
+ assert(!foos[1](true));
+}
test/cases/var_args.zig
@@ -54,3 +54,14 @@ fn extraFn(extra: u32, args: ...) -> usize {
}
return args.len;
}
+
+
+const foos = []fn(...) -> bool { foo1, foo2 };
+
+fn foo1(args: ...) -> bool { true }
+fn foo2(args: ...) -> bool { false }
+
+test "array of var args functions" {
+ assert(foos[0]());
+ assert(!foos[1]());
+}
test/compile_errors.zig
@@ -1904,4 +1904,16 @@ pub fn addCases(cases: &tests.CompileErrorContext) {
\\}
,
".tmp_source.zig:7:9: error: calling a generic function requires compile-time known function value");
+
+ cases.add("calling a generic function only known at runtime",
+ \\var foos = []fn(var) { foo1, foo2 };
+ \\
+ \\fn foo1(arg: var) {}
+ \\fn foo2(arg: var) {}
+ \\
+ \\pub fn main() -> %void {
+ \\ foos[0](true);
+ \\}
+ ,
+ ".tmp_source.zig:7:9: error: calling a generic function requires compile-time known function value");
}