Commit f12fbce0f5

Andrew Kelley <superjoe30@gmail.com>
2016-12-19 00:23:46
IR: memoize compile-time evaluated fn invocations
1 parent 4816121
src/all_types.hpp
@@ -1067,6 +1067,9 @@ struct BuiltinFnEntry {
     LLVMValueRef fn_val;
 };
 
+uint32_t fn_eval_hash(Scope*);
+bool fn_eval_eql(Scope *a, Scope *b);
+
 struct CodeGen {
     LLVMModuleRef module;
     ZigList<ErrorMsg*> errors;
@@ -1085,6 +1088,7 @@ struct CodeGen {
     HashMap<FnTypeId *, TypeTableEntry *, fn_type_id_hash, fn_type_id_eql> fn_type_table;
     HashMap<Buf *, ErrorTableEntry *, buf_hash, buf_eql_buf> error_table;
     HashMap<GenericFnTypeId *, FnTableEntry *, generic_fn_type_id_hash, generic_fn_type_id_eql> generic_table;
+    HashMap<Scope *, IrInstruction *, fn_eval_hash, fn_eval_eql> memoized_fn_eval_table;
 
     ZigList<ImportTableEntry *> import_queue;
     size_t import_queue_index;
src/analyze.cpp
@@ -2693,6 +2693,54 @@ bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
     return true;
 }
 
+uint32_t fn_eval_hash(Scope* scope) {
+    uint32_t result = 0;
+    while (scope) {
+        if (scope->id == ScopeIdVarDecl) {
+            ScopeVarDecl *var_scope = (ScopeVarDecl *)scope;
+            result += hash_const_val(var_scope->var->type, var_scope->var->value);
+        } else if (scope->id == ScopeIdFnDef) {
+            ScopeFnDef *fn_scope = (ScopeFnDef *)scope;
+            result += hash_ptr(fn_scope->fn_entry);
+            return result;
+        } else {
+            zig_unreachable();
+        }
+
+        scope = scope->parent;
+    }
+    zig_unreachable();
+}
+
+bool fn_eval_eql(Scope *a, Scope *b) {
+    while (a && b) {
+        if (a->id != b->id)
+            return false;
+
+        if (a->id == ScopeIdVarDecl) {
+            ScopeVarDecl *a_var_scope = (ScopeVarDecl *)a;
+            ScopeVarDecl *b_var_scope = (ScopeVarDecl *)b;
+            if (a_var_scope->var->type != b_var_scope->var->type)
+                return false;
+            if (!const_values_equal(a_var_scope->var->value, b_var_scope->var->value, a_var_scope->var->type))
+                return false;
+        } else if (a->id == ScopeIdFnDef) {
+            ScopeFnDef *a_fn_scope = (ScopeFnDef *)a;
+            ScopeFnDef *b_fn_scope = (ScopeFnDef *)b;
+            if (a_fn_scope->fn_entry != b_fn_scope->fn_entry)
+                return false;
+
+            return true;
+        } else {
+            zig_unreachable();
+        }
+
+        a = a->parent;
+        b = b->parent;
+    }
+    return false;
+}
+
 bool type_has_bits(TypeTableEntry *type_entry) {
     assert(type_entry);
     assert(type_entry->id != TypeTableEntryIdInvalid);
src/codegen.cpp
@@ -61,6 +61,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) {
     g->fn_type_table.init(32);
     g->error_table.init(16);
     g->generic_table.init(16);
+    g->memoized_fn_eval_table.init(16);
     g->is_release_build = false;
     g->is_test_build = false;
     g->want_h_file = true;
src/ir.cpp
@@ -6049,13 +6049,22 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
         if (return_type->id == TypeTableEntryIdInvalid)
             return ira->codegen->builtin_types.entry_invalid;
 
-        // Analyze the fn body block like any other constant expression.
-        AstNode *body_node = fn_entry->fn_def_node->data.fn_def.body;
-        IrInstruction *result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type,
-            ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota, fn_entry,
-            nullptr, call_instruction->base.source_node, nullptr);
-        if (result->type_entry->id == TypeTableEntryIdInvalid)
-            return ira->codegen->builtin_types.entry_invalid;
+        IrInstruction *result;
+
+        auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope);
+        if (entry) {
+            result = entry->value;
+        } else {
+            // Analyze the fn body block like any other constant expression.
+            AstNode *body_node = fn_entry->fn_def_node->data.fn_def.body;
+            result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type,
+                ira->new_irb.exec->backward_branch_count, ira->new_irb.exec->backward_branch_quota, fn_entry,
+                nullptr, call_instruction->base.source_node, nullptr);
+            if (result->type_entry->id == TypeTableEntryIdInvalid)
+                return ira->codegen->builtin_types.entry_invalid;
+
+            ira->codegen->memoized_fn_eval_table.put(exec_scope, result);
+        }
 
         ConstExprValue *out_val = ir_build_const_from(ira, &call_instruction->base,
                 result->static_value.depends_on_compile_var);