Commit 363606d87b

Andrew Kelley <superjoe30@gmail.com>
2016-12-05 07:08:17
IR: inline function evaluation works on generic functions
1 parent 25a89e7
Changed files (3)
src/analyze.cpp
@@ -878,7 +878,7 @@ static IrInstruction *analyze_const_value(CodeGen *g, Scope *scope, AstNode *nod
     return ir_eval_const_value(g, scope, node, type_entry, &backward_branch_count, default_backward_branch_quota);
 }
 
-static TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node) {
+TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node) {
     IrInstruction *result = analyze_const_value(g, scope, node, g->builtin_types.entry_type);
     if (result->type_entry->id == TypeTableEntryIdInvalid)
         return g->builtin_types.entry_invalid;
@@ -889,6 +889,19 @@ static TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node
 
 static TypeTableEntry *get_generic_fn_type(CodeGen *g, FnTypeId *fn_type_id) {
     TypeTableEntry *fn_type = new_type_table_entry(TypeTableEntryIdFn);
+    buf_init_from_str(&fn_type->name, "fn(");
+    size_t i = 0;
+    for (; i < fn_type_id->next_param_index; i += 1) {
+        const char *comma_str = (i == 0) ? "" : ",";
+        buf_appendf(&fn_type->name, "%s%s", comma_str,
+            buf_ptr(&fn_type_id->param_info[i].type->name));
+    }
+    for (; i < fn_type_id->param_count; i += 1) {
+        const char *comma_str = (i == 0) ? "" : ",";
+        buf_appendf(&fn_type->name, "%svar", comma_str);
+    }
+    buf_appendf(&fn_type->name, ")->var");
+
     fn_type->data.fn.fn_type_id = *fn_type_id;
     fn_type->data.fn.is_generic = true;
     return fn_type;
src/analyze.hpp
@@ -69,6 +69,7 @@ void init_tld(Tld *tld, TldId id, Buf *name, VisibMod visib_mod, AstNode *source
     Scope *parent_scope, Tld *parent_tld);
 VariableTableEntry *add_variable(CodeGen *g, AstNode *source_node, Scope *parent_scope, Buf *name,
     TypeTableEntry *type_entry, bool is_const, ConstExprValue *init_value);
+TypeTableEntry *analyze_type_expr(CodeGen *g, Scope *scope, AstNode *node);
 
 Scope *create_block_scope(AstNode *node, Scope *parent);
 Scope *create_defer_scope(AstNode *node, Scope *parent);
src/ir.cpp
@@ -2119,7 +2119,7 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     child_scope = elem_var->child_scope;
 
     IrInstruction *undefined_value = ir_build_const_undefined(irb, child_scope, elem_node);
-    ir_build_var_decl(irb, child_scope, elem_node, elem_var, elem_var_type, undefined_value); 
+    ir_build_var_decl(irb, child_scope, elem_node, elem_var, elem_var_type, undefined_value);
     IrInstruction *elem_var_ptr = ir_build_var_ptr(irb, child_scope, node, elem_var);
 
     AstNode *index_var_source_node;
@@ -2137,7 +2137,7 @@ static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNo
     IrInstruction *usize = ir_build_const_type(irb, child_scope, node, irb->codegen->builtin_types.entry_usize);
     IrInstruction *zero = ir_build_const_usize(irb, child_scope, node, 0);
     IrInstruction *one = ir_build_const_usize(irb, child_scope, node, 1);
-    ir_build_var_decl(irb, child_scope, index_var_source_node, index_var, usize, zero); 
+    ir_build_var_decl(irb, child_scope, index_var_source_node, index_var, usize, zero);
     IrInstruction *index_ptr = ir_build_var_ptr(irb, child_scope, node, index_var);
 
 
@@ -2347,7 +2347,7 @@ static IrInstruction *ir_gen_if_var_expr(IrBuilder *irb, Scope *scope, AstNode *
 
     IrInstruction *var_ptr_value = ir_build_unwrap_maybe(irb, scope, node, expr_value, false);
     IrInstruction *var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, node, var_ptr_value);
-    ir_build_var_decl(irb, scope, node, var, var_type, var_value); 
+    ir_build_var_decl(irb, scope, node, var, var_type, var_value);
     IrInstruction *then_expr_result = ir_gen_node(irb, then_node, var->child_scope);
     if (then_expr_result == irb->codegen->invalid_instruction)
         return then_expr_result;
@@ -2405,7 +2405,7 @@ static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *swit
             var_value = var_is_ptr ? target_value_ptr : ir_build_load_ptr(irb, scope, var_symbol_node, target_value_ptr);
         }
         IrInstruction *var_type = nullptr; // infer the type
-        ir_build_var_decl(irb, scope, var_symbol_node, var, var_type, var_value); 
+        ir_build_var_decl(irb, scope, var_symbol_node, var, var_type, var_value);
     } else {
         child_scope = scope;
     }
@@ -3287,44 +3287,6 @@ IrInstruction *ir_eval_const_value(CodeGen *codegen, Scope *scope, AstNode *node
     return result;
 }
 
-static IrInstruction *ir_eval_fn(IrAnalyze *ira, IrInstruction *source_instruction,
-    FnTableEntry *fn_entry, IrInstruction **args)
-{
-    if (!fn_entry) {
-        ir_add_error(ira, source_instruction,
-            buf_sprintf("unable to evaluate constant expression"));
-        return ira->codegen->invalid_instruction;
-    }
-
-    if (!ir_emit_backward_branch(ira, source_instruction))
-        return ira->codegen->invalid_instruction;
-
-    TypeTableEntry *fn_type = fn_entry->type_entry;
-    FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
-
-    // Fork a scope of the function with known values for the parameters.
-
-    Scope *exec_scope = &fn_entry->fndef_scope->base;
-    for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
-        AstNode *param_decl_node = fn_entry->proto_node->data.fn_proto.params.at(i);
-        Buf *param_name = param_decl_node->data.param_decl.name;
-        IrInstruction *arg = args[i];
-        ConstExprValue *arg_val = ir_resolve_const(ira, arg);
-        if (!arg_val)
-            return ira->codegen->invalid_instruction;
-
-        VariableTableEntry *var = add_variable(ira->codegen, param_decl_node, exec_scope, param_name,
-                arg->type_entry, true, arg_val);
-        exec_scope = var->child_scope;
-    }
-
-    // Analyze the fn body block like any other constant expression.
-
-    AstNode *body_node = fn_entry->fn_def_node->data.fn_def.body;
-    return ir_eval_const_value(ira->codegen, exec_scope, body_node, fn_type_id->return_type,
-        ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota);
-}
-
 static TypeTableEntry *ir_resolve_type_lval(IrAnalyze *ira, IrInstruction *type_value, LValPurpose lval) {
     if (lval != LValPurposeNone)
         zig_panic("TODO");
@@ -4257,6 +4219,33 @@ static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruc
     return ira->codegen->builtin_types.entry_void;
 }
 
+static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node,
+    IrInstruction *arg, Scope **exec_scope, size_t *next_arg_index)
+{
+    AstNode *param_decl_node = fn_proto_node->data.fn_proto.params.at(*next_arg_index);
+    assert(param_decl_node->type == NodeTypeParamDecl);
+    AstNode *param_type_node = param_decl_node->data.param_decl.type;
+    TypeTableEntry *param_type = analyze_type_expr(ira->codegen, *exec_scope, param_type_node);
+    if (param_type->id == TypeTableEntryIdInvalid)
+        return false;
+
+    IrInstruction *casted_arg = ir_get_casted_value(ira, arg, param_type);
+    if (casted_arg->type_entry->id == TypeTableEntryIdInvalid)
+        return false;
+
+    ConstExprValue *first_arg_val = ir_resolve_const(ira, casted_arg);
+    if (!first_arg_val)
+        return false;
+
+    Buf *param_name = param_decl_node->data.param_decl.name;
+    VariableTableEntry *var = add_variable(ira->codegen, param_decl_node,
+        *exec_scope, param_name, param_type, true, first_arg_val);
+    *exec_scope = var->child_scope;
+    *next_arg_index += 1;
+
+    return true;
+}
+
 static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *call_instruction,
     FnTableEntry *fn_entry, TypeTableEntry *fn_type, IrInstruction *fn_ref,
     IrInstruction *first_arg_ptr, bool is_inline)
@@ -4289,8 +4278,58 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
         return ira->codegen->builtin_types.entry_invalid;
     }
 
+    if (is_inline) {
+        if (!fn_entry) {
+            ir_add_error(ira, fn_ref, buf_sprintf("unable to evaluate constant expression"));
+            return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        if (!ir_emit_backward_branch(ira, &call_instruction->base))
+            return ira->codegen->builtin_types.entry_invalid;
+
+        // Fork a scope of the function with known values for the parameters.
+        Scope *exec_scope = &fn_entry->fndef_scope->base;
+
+        size_t next_arg_index = 0;
+        if (first_arg_ptr) {
+            IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr);
+            if (first_arg->type_entry->id == TypeTableEntryIdInvalid)
+                return ira->codegen->builtin_types.entry_invalid;
+
+            if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, first_arg, &exec_scope, &next_arg_index))
+                return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        for (size_t call_i = 0; call_i < call_instruction->arg_count; call_i += 1) {
+            IrInstruction *old_arg = call_instruction->args[call_i]->other;
+            if (old_arg->type_entry->id == TypeTableEntryIdInvalid)
+                return ira->codegen->builtin_types.entry_invalid;
+
+            if (!ir_analyze_fn_call_inline_arg(ira, fn_proto_node, old_arg, &exec_scope, &next_arg_index))
+                return ira->codegen->builtin_types.entry_invalid;
+        }
+
+        AstNode *return_type_node = fn_proto_node->data.fn_proto.return_type;
+        TypeTableEntry *return_type = analyze_type_expr(ira->codegen, exec_scope, return_type_node);
+        if (return_type->id == TypeTableEntryIdInvalid)
+            return ira->codegen->builtin_types.entry_invalid;
+
+        // Analyze the fn body block like any other constant expression.
+        AstNode *body_node = fn_entry->fn_def_node->data.fn_def.body;
+        IrInstruction *result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type,
+            ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota);
+        if (result->type_entry->id == TypeTableEntryIdInvalid)
+            return ira->codegen->builtin_types.entry_invalid;
+
+        ConstExprValue *out_val = ir_build_const_from(ira, &call_instruction->base,
+                result->static_value.depends_on_compile_var);
+        *out_val = result->static_value;
+        return ir_finish_anal(ira, return_type);
+    }
+
     IrInstruction **casted_args = allocate<IrInstruction *>(call_param_count);
     size_t next_arg_index = 0;
+
     if (first_arg_ptr) {
         IrInstruction *first_arg = ir_get_deref(ira, first_arg_ptr, first_arg_ptr);
         if (first_arg->type_entry->id == TypeTableEntryIdInvalid)
@@ -4304,9 +4343,6 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
         if (casted_arg->type_entry->id == TypeTableEntryIdInvalid)
             return ira->codegen->builtin_types.entry_invalid;
 
-        if (is_inline && !ir_resolve_const(ira, casted_arg))
-            return ira->codegen->builtin_types.entry_invalid;
-
         casted_args[next_arg_index] = casted_arg;
         next_arg_index += 1;
     }
@@ -4326,9 +4362,6 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
             casted_arg = old_arg;
         }
 
-        if (is_inline && !ir_resolve_const(ira, casted_arg))
-            return ira->codegen->builtin_types.entry_invalid;
-
         casted_args[next_arg_index] = casted_arg;
         next_arg_index += 1;
     }
@@ -4339,25 +4372,13 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
     if (return_type->id == TypeTableEntryIdInvalid)
         return ira->codegen->builtin_types.entry_invalid;
 
-    if (is_inline) {
-        assert(call_param_count == fn_type_id->param_count);
-        IrInstruction *result = ir_eval_fn(ira, &call_instruction->base, fn_entry, casted_args);
-        if (result->type_entry->id == TypeTableEntryIdInvalid)
-            return ira->codegen->builtin_types.entry_invalid;
-
-        ConstExprValue *out_val = ir_build_const_from(ira, &call_instruction->base,
-                result->static_value.depends_on_compile_var);
-        *out_val = result->static_value;
-        return ir_finish_anal(ira, return_type);
-    }
-
     IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
             fn_entry, fn_ref, call_param_count, casted_args);
 
     if (type_has_bits(return_type) && handle_is_ptr(return_type)) {
-        FnTableEntry *owner_fn = exec_fn_entry(ira->new_irb.exec);
-        assert(owner_fn);
-        owner_fn->alloca_list.append(new_call_instruction);
+        FnTableEntry *callsite_fn = exec_fn_entry(ira->new_irb.exec);
+        assert(callsite_fn);
+        callsite_fn->alloca_list.append(new_call_instruction);
     }
 
     return ir_finish_anal(ira, return_type);
@@ -5349,7 +5370,7 @@ static TypeTableEntry *ir_analyze_instruction_typeof(IrAnalyze *ira, IrInstructi
         case TypeTableEntryIdTypeDecl:
             {
                 ConstExprValue *out_val = ir_build_const_from(ira, &typeof_instruction->base, false);
-                // TODO depends_on_compile_var should be set based on whether the type of the expression 
+                // TODO depends_on_compile_var should be set based on whether the type of the expression
                 // depends_on_compile_var. but we currently don't have a thing to tell us if the type of
                 // something depends on a compile var
                 out_val->data.x_type = type_entry;