Commit 0541532ed6
Changed files (7)
src/all_types.hpp
@@ -726,6 +726,19 @@ struct FnTypeParamInfo {
TypeTableEntry *type;
};
+struct GenericParamValue {
+ TypeTableEntry *type;
+ ConstExprValue *value;
+};
+
+struct GenericFnTypeId {
+ FnTableEntry *fn_entry;
+ GenericParamValue *params;
+ size_t param_count;
+};
+
+uint32_t generic_fn_type_id_hash(GenericFnTypeId *id);
+bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b);
struct FnTypeId {
TypeTableEntry *return_type;
@@ -957,7 +970,6 @@ struct FnTableEntry {
ScopeFnDef *fndef_scope; // parent should be the top level decls or container decls
Scope *child_scope; // parent is scope for last parameter
ScopeBlock *def_scope; // parent is child_scope
- ImportTableEntry *import_entry;
Buf symbol_name;
TypeTableEntry *type_entry; // function type
TypeTableEntry *implicit_return_type;
@@ -969,6 +981,7 @@ struct FnTableEntry {
IrExecutable ir_executable;
IrExecutable analyzed_executable;
size_t prealloc_bbc;
+ AstNode **param_source_nodes;
AstNode *fn_no_inline_set_node;
AstNode *fn_export_set_node;
@@ -1050,6 +1063,7 @@ struct CodeGen {
HashMap<Buf *, TypeTableEntry *, buf_hash, buf_eql_buf> primitive_type_table;
HashMap<FnTypeId *, TypeTableEntry *, fn_type_id_hash, fn_type_id_eql> fn_type_table;
HashMap<Buf *, ErrorTableEntry *, buf_hash, buf_eql_buf> error_table;
+ HashMap<GenericFnTypeId *, FnTableEntry *, generic_fn_type_id_hash, generic_fn_type_id_eql> generic_table;
ZigList<ImportTableEntry *> import_queue;
size_t import_queue_index;
@@ -1201,7 +1215,6 @@ struct VariableTableEntry {
Scope *parent_scope;
Scope *child_scope;
LLVMValueRef param_value_ref;
- bool force_depends_on_compile_var;
bool shadowable;
size_t mem_slot_index;
size_t ref_count;
src/analyze.cpp
@@ -907,22 +907,25 @@ static TypeTableEntry *get_generic_fn_type(CodeGen *g, FnTypeId *fn_type_id) {
return fn_type;
}
-static TypeTableEntry *analyze_fn_type(CodeGen *g, TldFn *tld_fn) {
- AstNode *proto_node = tld_fn->base.source_node;
+void init_fn_type_id(FnTypeId *fn_type_id, AstNode *proto_node) {
assert(proto_node->type == NodeTypeFnProto);
AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
- FnTypeId fn_type_id = {0};
- fn_type_id.is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport);
- fn_type_id.is_naked = fn_proto->is_nakedcc;
- fn_type_id.is_cold = fn_proto->is_coldcc;
- fn_type_id.param_count = fn_proto->params.length;
- fn_type_id.param_info = allocate_nonzero<FnTypeParamInfo>(fn_type_id.param_count);
- fn_type_id.next_param_index = 0;
- fn_type_id.is_var_args = fn_proto->is_var_args;
+ fn_type_id->is_extern = fn_proto->is_extern || (fn_proto->visib_mod == VisibModExport);
+ fn_type_id->is_naked = fn_proto->is_nakedcc;
+ fn_type_id->is_cold = fn_proto->is_coldcc;
+ fn_type_id->param_count = fn_proto->params.length;
+ fn_type_id->param_info = allocate_nonzero<FnTypeParamInfo>(fn_type_id->param_count);
+ fn_type_id->next_param_index = 0;
+ fn_type_id->is_var_args = fn_proto->is_var_args;
+}
- FnTableEntry *fn_entry = tld_fn->fn_entry;
- Scope *child_scope = fn_entry->fndef_scope ? &fn_entry->fndef_scope->base : tld_fn->base.parent_scope;
+static TypeTableEntry *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_scope) {
+ assert(proto_node->type == NodeTypeFnProto);
+ AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
+
+ FnTypeId fn_type_id = {0};
+ init_fn_type_id(&fn_type_id, proto_node);
for (; fn_type_id.next_param_index < fn_type_id.param_count; fn_type_id.next_param_index += 1) {
AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index);
@@ -939,10 +942,6 @@ static TypeTableEntry *analyze_fn_type(CodeGen *g, TldFn *tld_fn) {
return get_generic_fn_type(g, &fn_type_id);
}
- if (fn_entry && buf_len(param_node->data.param_decl.name) == 0) {
- add_node_error(g, param_node, buf_sprintf("missing parameter name"));
- }
-
TypeTableEntry *type_entry = analyze_type_expr(g, child_scope, param_node->data.param_decl.type);
switch (type_entry->id) {
@@ -1366,6 +1365,23 @@ static void get_fully_qualified_decl_name(Buf *buf, Tld *tld, uint8_t sep) {
}
}
+FnTableEntry *create_fn(CodeGen *g, AstNode *proto_node) {
+ assert(proto_node->type == NodeTypeFnProto);
+ AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
+
+ FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
+ fn_table_entry->analyzed_executable.backward_branch_count = &fn_table_entry->prealloc_bbc;
+ fn_table_entry->analyzed_executable.backward_branch_quota = default_backward_branch_quota;
+ fn_table_entry->analyzed_executable.fn_entry = fn_table_entry;
+ fn_table_entry->ir_executable.fn_entry = fn_table_entry;
+ fn_table_entry->proto_node = proto_node;
+ fn_table_entry->fn_def_node = proto_node->data.fn_proto.fn_def_node;
+ fn_table_entry->fn_inline = fn_proto->is_inline ? FnInlineAlways : FnInlineAuto;
+ fn_table_entry->internal_linkage = (fn_proto->visib_mod != VisibModExport);
+
+ return fn_table_entry;
+}
+
static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
ImportTableEntry *import = tld_fn->base.import;
AstNode *proto_node = tld_fn->base.source_node;
@@ -1381,17 +1397,7 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
return;
}
- FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
- fn_table_entry->analyzed_executable.backward_branch_count = &fn_table_entry->prealloc_bbc;
- fn_table_entry->analyzed_executable.backward_branch_quota = default_backward_branch_quota;
- fn_table_entry->analyzed_executable.fn_entry = fn_table_entry;
- fn_table_entry->ir_executable.fn_entry = fn_table_entry;
- fn_table_entry->import_entry = import;
- fn_table_entry->proto_node = proto_node;
- fn_table_entry->fn_def_node = fn_def_node;
- fn_table_entry->fn_inline = fn_proto->is_inline ? FnInlineAlways : FnInlineAuto;
- fn_table_entry->internal_linkage = (fn_proto->visib_mod != VisibModExport);
-
+ FnTableEntry *fn_table_entry = create_fn(g, tld_fn->base.source_node);
get_fully_qualified_decl_name(&fn_table_entry->symbol_name, &tld_fn->base, '_');
tld_fn->fn_entry = fn_table_entry;
@@ -1399,9 +1405,18 @@ static void resolve_decl_fn(CodeGen *g, TldFn *tld_fn) {
if (fn_table_entry->fn_def_node) {
fn_table_entry->fndef_scope = create_fndef_scope(
fn_table_entry->fn_def_node, tld_fn->base.parent_scope, fn_table_entry);
+
+ for (size_t i = 0; i < fn_proto->params.length; i += 1) {
+ AstNode *param_node = fn_proto->params.at(i);
+ assert(param_node->type == NodeTypeParamDecl);
+ if (buf_len(param_node->data.param_decl.name) == 0) {
+ add_node_error(g, param_node, buf_sprintf("missing parameter name"));
+ }
+ }
}
- fn_table_entry->type_entry = analyze_fn_type(g, tld_fn);
+ Scope *child_scope = fn_table_entry->fndef_scope ? &fn_table_entry->fndef_scope->base : tld_fn->base.parent_scope;
+ fn_table_entry->type_entry = analyze_fn_type(g, proto_node, child_scope);
if (fn_table_entry->type_entry->id == TypeTableEntryIdInvalid) {
tld_fn->base.resolution = TldResolutionInvalid;
@@ -2142,6 +2157,13 @@ bool type_is_codegen_pointer(TypeTableEntry *type) {
return false;
}
+AstNode *get_param_decl_node(FnTableEntry *fn_entry, size_t index) {
+ if (fn_entry->param_source_nodes)
+ return fn_entry->param_source_nodes[index];
+ else
+ return fn_entry->proto_node->data.fn_proto.params.at(index);
+}
+
static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) {
assert(fn_table_entry->anal_state != FnAnalStateProbing);
if (fn_table_entry->anal_state != FnAnalStateReady)
@@ -2151,17 +2173,19 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) {
AstNodeFnProto *fn_proto = &fn_table_entry->proto_node->data.fn_proto;
- Scope *child_scope = &fn_table_entry->fndef_scope->base;
- assert(child_scope);
+ assert(fn_table_entry->fndef_scope);
+ if (!fn_table_entry->child_scope)
+ fn_table_entry->child_scope = &fn_table_entry->fndef_scope->base;
// define local variables for parameters
TypeTableEntry *fn_type = fn_table_entry->type_entry;
assert(!fn_type->data.fn.is_generic);
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
- AstNode *param_decl_node = fn_proto->params.at(i);
- AstNodeParamDecl *param_decl = ¶m_decl_node->data.param_decl;
FnTypeParamInfo *param_info = &fn_type_id->param_info[i];
+ AstNode *param_decl_node = get_param_decl_node(fn_table_entry, i);
+ AstNodeParamDecl *param_decl = ¶m_decl_node->data.param_decl;
+
TypeTableEntry *param_type = param_info->type;
bool is_noalias = param_info->is_noalias;
@@ -2175,9 +2199,9 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) {
buf_sprintf("byvalue types not yet supported on extern function parameters"));
}
- VariableTableEntry *var = add_variable(g, param_decl_node, child_scope, param_decl->name, param_type, true, nullptr);
+ VariableTableEntry *var = add_variable(g, param_decl_node, fn_table_entry->child_scope, param_decl->name, param_type, true, nullptr);
var->src_arg_index = i;
- child_scope = var->child_scope;
+ fn_table_entry->child_scope = var->child_scope;
fn_table_entry->variable_list.append(var);
if (fn_type->data.fn.gen_param_info) {
@@ -2185,8 +2209,6 @@ static void analyze_fn_body(CodeGen *g, FnTableEntry *fn_table_entry) {
}
}
- fn_table_entry->child_scope = child_scope;
-
TypeTableEntry *expected_type = fn_type_id->return_type;
if (fn_type_id->is_extern && handle_is_ptr(expected_type)) {
@@ -2622,6 +2644,40 @@ static uint32_t hash_const_val(TypeTableEntry *type, ConstExprValue *const_val)
zig_unreachable();
}
+uint32_t generic_fn_type_id_hash(GenericFnTypeId *id) {
+ uint32_t result = 0;
+ result += hash_ptr(id->fn_entry);
+ for (size_t i = 0; i < id->param_count; i += 1) {
+ GenericParamValue *generic_param = &id->params[i];
+ if (generic_param->value) {
+ result += hash_const_val(generic_param->type, generic_param->value);
+ result += hash_ptr(generic_param->type);
+ }
+ }
+ return result;
+}
+
+bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
+ assert(a->fn_entry);
+ if (a->fn_entry != b->fn_entry) return false;
+ assert(a->param_count == b->param_count);
+ for (size_t i = 0; i < a->param_count; i += 1) {
+ GenericParamValue *a_val = &a->params[i];
+ GenericParamValue *b_val = &b->params[i];
+ if (a_val->type != b_val->type) return false;
+ if (a_val->value && b_val->value) {
+ assert(a_val->value->special == ConstValSpecialStatic);
+ assert(b_val->value->special == ConstValSpecialStatic);
+ if (!const_values_equal(a_val->value, b_val->value, a_val->type)) {
+ return false;
+ }
+ } else {
+ assert(!a_val->value && !b_val->value);
+ }
+ }
+ return true;
+}
+
bool type_has_bits(TypeTableEntry *type_entry) {
assert(type_entry);
assert(type_entry->id != TypeTableEntryIdInvalid);
src/analyze.hpp
@@ -70,6 +70,9 @@ void init_tld(Tld *tld, TldId id, Buf *name, VisibMod visib_mod, AstNode *source
VariableTableEntry *add_variable(CodeGen *g, AstNode *source_node, Scope *parent_scope, Buf *name,
TypeTableEntry *type_entry, bool is_const, ConstExprValue *init_value);
TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node);
+FnTableEntry *create_fn(CodeGen *g, AstNode *proto_node);
+void init_fn_type_id(FnTypeId *fn_type_id, AstNode *proto_node);
+AstNode *get_param_decl_node(FnTableEntry *fn_entry, size_t index);
Scope *create_block_scope(AstNode *node, Scope *parent);
Scope *create_defer_scope(AstNode *node, Scope *parent);
src/ast_render.cpp
@@ -397,7 +397,7 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
assert(param_decl->type == NodeTypeParamDecl);
if (buf_len(param_decl->data.param_decl.name) > 0) {
const char *noalias_str = param_decl->data.param_decl.is_noalias ? "noalias " : "";
- const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : "";
+ const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : "";
fprintf(ar->f, "%s%s", noalias_str, inline_str);
print_symbol(ar, param_decl->data.param_decl.name);
fprintf(ar->f, ": ");
src/codegen.cpp
@@ -60,6 +60,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) {
g->primitive_type_table.init(32);
g->fn_type_table.init(32);
g->error_table.init(16);
+ g->generic_table.init(16);
g->is_release_build = false;
g->is_test_build = false;
g->want_h_file = true;
@@ -2305,11 +2306,8 @@ static void do_code_gen(CodeGen *g) {
if (should_skip_fn_codegen(g, fn_table_entry))
continue;
- AstNode *proto_node = fn_table_entry->proto_node;
- assert(proto_node->type == NodeTypeFnProto);
- AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
-
TypeTableEntry *fn_type = fn_table_entry->type_entry;
+ FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
LLVMValueRef fn_val = fn_llvm_value(g, fn_table_entry);
@@ -2327,22 +2325,20 @@ static void do_code_gen(CodeGen *g) {
// set parameter attributes
- for (size_t param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) {
- AstNode *param_node = fn_proto->params.at(param_decl_i);
- assert(param_node->type == NodeTypeParamDecl);
-
- FnGenParamInfo *info = &fn_type->data.fn.gen_param_info[param_decl_i];
- size_t gen_index = info->gen_index;
- bool is_byval = info->is_byval;
+ for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) {
+ FnGenParamInfo *gen_info = &fn_type->data.fn.gen_param_info[param_i];
+ size_t gen_index = gen_info->gen_index;
+ bool is_byval = gen_info->is_byval;
if (gen_index == SIZE_MAX) {
continue;
}
- TypeTableEntry *param_type = info->type;
+ FnTypeParamInfo *param_info = &fn_type_id->param_info[param_i];
+
+ TypeTableEntry *param_type = gen_info->type;
LLVMValueRef argument_val = LLVMGetParam(fn_val, gen_index);
- bool param_is_noalias = param_node->data.param_decl.is_noalias;
- if (param_is_noalias) {
+ if (param_info->is_noalias) {
LLVMAddAttribute(argument_val, LLVMNoAliasAttribute);
}
if ((param_type->id == TypeTableEntryIdPointer && param_type->data.pointer.is_const) || is_byval) {
@@ -2402,7 +2398,6 @@ static void do_code_gen(CodeGen *g) {
if (should_skip_fn_codegen(g, fn_table_entry))
continue;
- ImportTableEntry *import = fn_table_entry->import_entry;
LLVMValueRef fn = fn_llvm_value(g, fn_table_entry);
g->cur_fn = fn_table_entry;
g->cur_fn_val = fn;
@@ -2412,10 +2407,6 @@ static void do_code_gen(CodeGen *g) {
g->cur_ret_ptr = nullptr;
}
- AstNode *proto_node = fn_table_entry->proto_node;
- assert(proto_node->type == NodeTypeFnProto);
- AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
-
build_all_basic_blocks(g, fn_table_entry);
clear_debug_source_node(g);
@@ -2444,6 +2435,8 @@ static void do_code_gen(CodeGen *g) {
*slot = LLVMBuildAlloca(g->builder, instruction->type_entry->type_ref, "");
}
+ ImportTableEntry *import = get_scope_import(&fn_table_entry->fndef_scope->base);
+
// create debug variable declarations for variables and allocate all local variables
for (size_t var_i = 0; var_i < fn_table_entry->variable_list.length; var_i += 1) {
VariableTableEntry *var = fn_table_entry->variable_list.at(var_i);
@@ -2484,10 +2477,12 @@ static void do_code_gen(CodeGen *g) {
}
}
+ FnTypeId *fn_type_id = &fn_table_entry->type_entry->data.fn.fn_type_id;
+
// create debug variable declarations for parameters
// rely on the first variables in the variable_list being parameters.
size_t next_var_i = 0;
- for (size_t param_i = 0; param_i < fn_proto->params.length; param_i += 1) {
+ for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) {
FnGenParamInfo *info = &fn_table_entry->type_entry->data.fn.gen_param_info[param_i];
if (info->gen_index == SIZE_MAX)
continue;
@@ -3392,14 +3387,12 @@ void codegen_generate_h_file(CodeGen *g) {
buf_resize(&h_buf, 0);
for (size_t fn_def_i = 0; fn_def_i < g->fn_defs.length; fn_def_i += 1) {
FnTableEntry *fn_table_entry = g->fn_defs.at(fn_def_i);
- AstNode *proto_node = fn_table_entry->proto_node;
- assert(proto_node->type == NodeTypeFnProto);
- AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
- if (fn_proto->visib_mod != VisibModExport)
+ if (fn_table_entry->internal_linkage)
continue;
FnTypeId *fn_type_id = &fn_table_entry->type_entry->data.fn.fn_type_id;
+
Buf return_type_c = BUF_INIT;
get_c_type(g, fn_type_id->return_type, &return_type_c);
@@ -3412,7 +3405,7 @@ void codegen_generate_h_file(CodeGen *g) {
if (fn_type_id->param_count > 0) {
for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) {
FnTypeParamInfo *param_info = &fn_type_id->param_info[param_i];
- AstNode *param_decl_node = fn_proto->params.at(param_i);
+ AstNode *param_decl_node = get_param_decl_node(fn_table_entry, param_i);
Buf *param_name = param_decl_node->data.param_decl.name;
const char *comma_str = (param_i == 0) ? "" : ", ";
src/ir.cpp
@@ -2825,9 +2825,8 @@ IrInstruction *ir_gen_fn(CodeGen *codegen, FnTableEntry *fn_entry) {
AstNode *body_node = fn_def_node->data.fn_def.body;
assert(fn_entry->child_scope);
- Scope *child_scope = fn_entry->child_scope;
- return ir_gen(codegen, body_node, child_scope, ir_executable);
+ return ir_gen(codegen, body_node, fn_entry->child_scope, ir_executable);
}
static ErrorMsg *ir_add_error(IrAnalyze *ira, IrInstruction *source_instruction, Buf *msg) {
@@ -4220,9 +4219,9 @@ static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruc
}
static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node,
- IrInstruction *arg, Scope **exec_scope, size_t *next_arg_index)
+ IrInstruction *arg, Scope **exec_scope, size_t *next_proto_i)
{
- AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_arg_index);
+ AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_proto_i);
assert(param_decl_node->type == NodeTypeParamDecl);
AstNode *param_type_node = param_decl_node->data.param_decl.type;
TypeTableEntry *param_type = analyze_type_expr(ira->codegen, *exec_scope, param_type_node);
@@ -4241,14 +4240,66 @@ static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node
VariableTableEntry *var = add_variable(ira->codegen, param_decl_node,
*exec_scope, param_name, param_type, true, first_arg_val);
*exec_scope = var->child_scope;
- *next_arg_index += 1;
+ *next_proto_i += 1;
return true;
}
+static bool ir_analyze_fn_call_generic_arg(IrAnalyze *ira, AstNode *fn_proto_node,
+ IrInstruction *arg, Scope **child_scope, size_t *next_proto_i,
+ GenericFnTypeId *generic_id, FnTypeId *fn_type_id, IrInstruction **casted_args,
+ FnTableEntry *impl_fn)
+{
+ AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_proto_i);
+ assert(param_decl_node->type == NodeTypeParamDecl);
+ AstNode *param_type_node = param_decl_node->data.param_decl.type;
+ TypeTableEntry *param_type = analyze_type_expr(ira->codegen, *child_scope, param_type_node);
+ if (param_type->id == TypeTableEntryIdInvalid)
+ return false;
+
+ bool is_var_type = (param_type->id == TypeTableEntryIdVar);
+ IrInstruction *casted_arg;
+ if (is_var_type) {
+ casted_arg = arg;
+ } else {
+ casted_arg = ir_get_casted_value(ira, arg, param_type);
+ if (casted_arg->type_entry->id == TypeTableEntryIdInvalid)
+ return false;
+ }
+
+ bool inline_arg = param_decl_node->data.param_decl.is_inline;
+ if (inline_arg || is_var_type) {
+ ConstExprValue *arg_val = ir_resolve_const(ira, casted_arg);
+ if (!arg_val)
+ return false;
+
+ Buf *param_name = param_decl_node->data.param_decl.name;
+ VariableTableEntry *var = add_variable(ira->codegen, param_decl_node,
+ *child_scope, param_name, param_type, true, arg_val);
+ *child_scope = var->child_scope;
+ // This generic function instance could be called with anything, so when this variable is read it
+ // needs to know that it depends on compile time variable data.
+ var->value->depends_on_compile_var = true;
+
+ GenericParamValue *generic_param = &generic_id->params[generic_id->param_count];
+ generic_param->type = casted_arg->type_entry;
+ generic_param->value = arg_val;
+ generic_id->param_count += 1;
+ } else {
+ casted_args[fn_type_id->param_count] = casted_arg;
+ FnTypeParamInfo *param_info = &fn_type_id->param_info[fn_type_id->param_count];
+ param_info->type = param_type;
+ param_info->is_noalias = param_decl_node->data.param_decl.is_noalias;
+ impl_fn->param_source_nodes[fn_type_id->param_count] = param_decl_node;
+ fn_type_id->param_count += 1;
+ }
+ *next_proto_i += 1;
+ return true;
+}
+
static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *call_instruction,
FnTableEntry *fn_entry, TypeTableEntry *fn_type, IrInstruction *fn_ref,
- IrInstruction *first_arg_ptr, bool is_inline)
+ IrInstruction *first_arg_ptr, bool inline_fn_call)
{
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
size_t first_arg_1_or_0 = first_arg_ptr ? 1 : 0;
@@ -4278,7 +4329,8 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
return ira->codegen->builtin_types.entry_invalid;
}
- if (is_inline) {
+ if (inline_fn_call) {
+ // No special handling is needed for compile time evaluation of generic functions.
if (!fn_entry) {
ir_add_error(ira, fn_ref, buf_sprintf("unable to evaluate constant expression"));
return ira->codegen->builtin_types.entry_invalid;
@@ -4290,13 +4342,13 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
// Fork a scope of the function with known values for the parameters.
Scope *exec_scope = &fn_entry->fndef_scope->base;
- size_t next_arg_index = 0;
+ size_t next_proto_i = 0;
if (first_arg_ptr) {
IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr);
if (first_arg->type_entry->id == TypeTableEntryIdInvalid)
return ira->codegen->builtin_types.entry_invalid;
- if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, first_arg, &exec_scope, &next_arg_index))
+ if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, first_arg, &exec_scope, &next_proto_i))
return ira->codegen->builtin_types.entry_invalid;
}
@@ -4305,7 +4357,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
if (old_arg->type_entry->id == TypeTableEntryIdInvalid)
return ira->codegen->builtin_types.entry_invalid;
- if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, old_arg, &exec_scope, &next_arg_index))
+ if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, old_arg, &exec_scope, &next_proto_i))
return ira->codegen->builtin_types.entry_invalid;
}
@@ -4327,9 +4379,89 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
return ir_finish_anal(ira, return_type);
}
+ if (fn_type->data.fn.is_generic) {
+ assert(fn_entry);
+
+ IrInstruction **casted_args = allocate<IrInstruction *>(call_param_count);
+
+ // Fork a scope of the function with known values for the parameters.
+ Scope *parent_scope = fn_entry->fndef_scope->base.parent;
+ FnTableEntry *impl_fn = create_fn(ira->codegen, fn_proto_node);
+ impl_fn->param_source_nodes = allocate<AstNode *>(call_param_count);
+ buf_init_from_buf(&impl_fn->symbol_name, &fn_entry->symbol_name);
+ impl_fn->fndef_scope = create_fndef_scope(impl_fn->fn_def_node, parent_scope, impl_fn);
+ impl_fn->child_scope = &impl_fn->fndef_scope->base;
+ FnTypeId fn_type_id = {0};
+ init_fn_type_id(&fn_type_id, fn_proto_node);
+ fn_type_id.param_count = 0;
+
+ // TODO maybe GenericFnTypeId can be replaced with using the child_scope directly
+ // as the key in generic_table
+ GenericFnTypeId *generic_id = allocate<GenericFnTypeId>(1);
+ generic_id->fn_entry = fn_entry;
+ generic_id->param_count = 0;
+ generic_id->params = allocate<GenericParamValue>(src_param_count);
+ size_t next_proto_i = 0;
+
+ if (first_arg_ptr) {
+ IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr);
+ if (first_arg->type_entry->id == TypeTableEntryIdInvalid)
+ return ira->codegen->builtin_types.entry_invalid;
+
+ if (!ir_analyze_fn_call_generic_arg(ira, fn_proto_node, first_arg, &impl_fn->child_scope,
+ &next_proto_i, generic_id, &fn_type_id, casted_args, impl_fn))
+ {
+ return ira->codegen->builtin_types.entry_invalid;
+ }
+ }
+ for (size_t call_i = 0; call_i < call_instruction->arg_count; call_i += 1) {
+ IrInstruction *arg = call_instruction->args[call_i]->other;
+ if (arg->type_entry->id == TypeTableEntryIdInvalid)
+ return ira->codegen->builtin_types.entry_invalid;
+
+ if (!ir_analyze_fn_call_generic_arg(ira, fn_proto_node, arg, &impl_fn->child_scope,
+ &next_proto_i, generic_id, &fn_type_id, casted_args, impl_fn))
+ {
+ return ira->codegen->builtin_types.entry_invalid;
+ }
+ }
+
+ auto existing_entry = ira->codegen->generic_table.put_unique(generic_id, impl_fn);
+ if (existing_entry) {
+ // throw away all our work and use the existing function
+ impl_fn = existing_entry->value;
+ } else {
+ // finish instantiating the function
+ AstNode *return_type_node = fn_proto_node->data.fn_proto.return_type;
+ TypeTableEntry *return_type = analyze_type_expr(ira->codegen, impl_fn->child_scope, return_type_node);
+ if (return_type->id == TypeTableEntryIdInvalid)
+ return ira->codegen->builtin_types.entry_invalid;
+ fn_type_id.return_type = return_type;
+
+ impl_fn->type_entry = get_fn_type(ira->codegen, &fn_type_id);
+ if (impl_fn->type_entry->id == TypeTableEntryIdInvalid)
+ return ira->codegen->builtin_types.entry_invalid;
+
+ ira->codegen->fn_protos.append(impl_fn);
+ ira->codegen->fn_defs.append(impl_fn);
+ }
+
+ size_t impl_param_count = impl_fn->type_entry->data.fn.fn_type_id.param_count;
+ IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
+ impl_fn, nullptr, impl_param_count, casted_args);
+
+ TypeTableEntry *return_type = impl_fn->type_entry->data.fn.fn_type_id.return_type;
+ if (type_has_bits(return_type) && handle_is_ptr(return_type)) {
+ FnTableEntry *callsite_fn = exec_fn_entry(ira->new_irb.exec);
+ assert(callsite_fn);
+ callsite_fn->alloca_list.append(new_call_instruction);
+ }
+
+ return ir_finish_anal(ira, return_type);
+ }
+
IrInstruction **casted_args = allocate<IrInstruction *>(call_param_count);
size_t next_arg_index = 0;
-
if (first_arg_ptr) {
IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr);
if (first_arg->type_entry->id == TypeTableEntryIdInvalid)
test/self_hosted2.zig
@@ -111,6 +111,33 @@ fn testCompileTimeFib() {
assert(fib_7 == 13);
}
+fn max(inline T: type, a: T, b: T) -> T {
+ if (a > b) a else b
+}
+const the_max = max(u32, 1234, 5678);
+
+fn testCompileTimeGenericEval() {
+ assert(the_max == 5678);
+}
+
+fn gimmeTheBigOne(a: u32, b: u32) -> u32 {
+ max(u32, a, b)
+}
+
+fn shouldCallSameInstance(a: u32, b: u32) -> u32 {
+ max(u32, a, b)
+}
+
+fn sameButWithFloats(a: f64, b: f64) -> f64 {
+ max(f64, a, b)
+}
+
+fn testFnWithInlineArgs() {
+ assert(gimmeTheBigOne(1234, 5678) == 5678);
+ assert(shouldCallSameInstance(34, 12) == 34);
+ assert(sameButWithFloats(0.43, 0.49) == 0.49);
+}
+
fn assert(ok: bool) {
if (!ok)
@@ -129,6 +156,8 @@ fn runAllTests() {
testStructStatic();
testStaticFnEval();
testCompileTimeFib();
+ testCompileTimeGenericEval();
+ testFnWithInlineArgs();
}
export nakedcc fn _start() -> unreachable {