Commit 6db9be8900

Andrew Kelley <superjoe30@gmail.com>
2018-03-09 20:20:44
don't memoize comptime functions if they can mutate state via parameters
closes #639
1 parent aaf2230
src/all_types.hpp
@@ -1168,10 +1168,17 @@ struct TypeTableEntry {
     LLVMTypeRef type_ref;
     ZigLLVMDIType *di_type;
 
-    bool zero_bits;
+    bool zero_bits; // this is denormalized data
     bool is_copyable;
     bool gen_h_loop_flag;
 
+    // This is denormalized data. The simplest type that has this
+    // flag set to true is a mutable pointer. A const pointer has
+    // the same value for this flag as the child type.
+    // If a struct has any fields that have this flag true, then
+    // the flag is true for the struct.
+    bool can_mutate_state_through_it;
+
     union {
         TypeTableEntryPointer pointer;
         TypeTableEntryInt integral;
src/analyze.cpp
@@ -398,6 +398,7 @@ TypeTableEntry *get_pointer_to_type_extra(CodeGen *g, TypeTableEntry *child_type
 
     TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdPointer);
     entry->is_copyable = true;
+    entry->can_mutate_state_through_it = is_const ? child_type->can_mutate_state_through_it : true;
 
     const char *const_str = is_const ? "const " : "";
     const char *volatile_str = is_volatile ? "volatile " : "";
@@ -482,6 +483,7 @@ TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type) {
         assert(child_type->type_ref || child_type->zero_bits);
         assert(child_type->di_type);
         entry->is_copyable = type_is_copyable(g, child_type);
+        entry->can_mutate_state_through_it = child_type->can_mutate_state_through_it;
 
         buf_resize(&entry->name, 0);
         buf_appendf(&entry->name, "?%s", buf_ptr(&child_type->name));
@@ -572,6 +574,7 @@ TypeTableEntry *get_error_union_type(CodeGen *g, TypeTableEntry *err_set_type, T
     entry->is_copyable = true;
     assert(payload_type->di_type);
     ensure_complete_type(g, payload_type);
+    entry->can_mutate_state_through_it = payload_type->can_mutate_state_through_it;
 
     buf_resize(&entry->name, 0);
     buf_appendf(&entry->name, "%s!%s", buf_ptr(&err_set_type->name), buf_ptr(&payload_type->name));
@@ -730,6 +733,7 @@ TypeTableEntry *get_slice_type(CodeGen *g, TypeTableEntry *ptr_type) {
 
     TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdStruct);
     entry->is_copyable = true;
+    entry->can_mutate_state_through_it = ptr_type->can_mutate_state_through_it;
 
     // replace the & with [] to go from a ptr type name to a slice type name
     buf_resize(&entry->name, 0);
@@ -1735,6 +1739,8 @@ TypeTableEntry *get_struct_type(CodeGen *g, const char *type_name, const char *f
             struct_type->data.structure.gen_field_count += 1;
         } else {
             field->gen_index = SIZE_MAX;
+            struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it ||
+                field->type_entry->can_mutate_state_through_it;
         }
 
         auto prev_entry = struct_type->data.structure.fields_by_name.put_unique(field->name, field);
@@ -2475,6 +2481,9 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) {
         if (!type_has_bits(field_type))
             continue;
 
+        struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it ||
+            field_type->can_mutate_state_through_it;
+
         if (gen_field_index == 0) {
             if (struct_type->data.structure.layout == ContainerLayoutPacked) {
                 struct_type->data.structure.abi_alignment = 1;
@@ -2662,6 +2671,8 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) {
             }
         }
         union_field->type_entry = field_type;
+        union_type->can_mutate_state_through_it = union_type->can_mutate_state_through_it ||
+            field_type->can_mutate_state_through_it;
 
         if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) {
             ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value,
@@ -4565,6 +4576,23 @@ bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
     return true;
 }
 
+bool fn_eval_cacheable(Scope *scope) {
+    while (scope) {
+        if (scope->id == ScopeIdVarDecl) {
+            ScopeVarDecl *var_scope = (ScopeVarDecl *)scope;
+            if (var_scope->var->value->type->can_mutate_state_through_it)
+                return false;
+        } else if (scope->id == ScopeIdFnDef) {
+            return true;
+        } else {
+            zig_unreachable();
+        }
+
+        scope = scope->parent;
+    }
+    zig_unreachable();
+}
+
 uint32_t fn_eval_hash(Scope* scope) {
     uint32_t result = 0;
     while (scope) {
src/analyze.hpp
@@ -195,5 +195,6 @@ TypeTableEntry *get_auto_err_set_type(CodeGen *g, FnTableEntry *fn_entry);
 
 uint32_t get_coro_frame_align_bytes(CodeGen *g);
 bool fn_type_can_fail(FnTypeId *fn_type_id);
+bool fn_eval_cacheable(Scope *scope);
 
 #endif
src/ir.cpp
@@ -11830,12 +11830,15 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
             return_type = specified_return_type;
         }
 
-        IrInstruction *result;
+        bool cacheable = fn_eval_cacheable(exec_scope);
+        IrInstruction *result = nullptr;
+        if (cacheable) {
+            auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope);
+            if (entry)
+                result = entry->value;
+        }
 
-        auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope);
-        if (entry) {
-            result = entry->value;
-        } else {
+        if (result == nullptr) {
             // Analyze the fn body block like any other constant expression.
             AstNode *body_node = fn_entry->body_node;
             result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type,
@@ -11859,7 +11862,9 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
                 }
             }
 
-            ira->codegen->memoized_fn_eval_table.put(exec_scope, result);
+            if (cacheable) {
+                ira->codegen->memoized_fn_eval_table.put(exec_scope, result);
+            }
 
             if (type_is_invalid(result->value.type))
                 return ira->codegen->builtin_types.entry_invalid;
std/sort.zig
@@ -964,8 +964,7 @@ fn u8desc(lhs: &const u8, rhs: &const u8) bool {
 
 test "stable sort" {
     testStableSort();
-    // TODO: uncomment this after https://github.com/zig-lang/zig/issues/639
-    //comptime testStableSort();
+    comptime testStableSort();
 }
 fn testStableSort() void {
     var expected = []IdAndValue {
test/cases/eval.zig
@@ -420,3 +420,31 @@ test "binary math operator in partially inlined function" {
     assert(s[2] == 0x90a0b0c);
     assert(s[3] == 0xd0e0f10);
 }
+
+
+test "comptime function with the same args is memoized" {
+    comptime {
+        assert(MakeType(i32) == MakeType(i32));
+        assert(MakeType(i32) != MakeType(f64));
+    }
+}
+
+fn MakeType(comptime T: type) type {
+    return struct {
+        field: T,
+    };
+}
+
+test "comptime function with mutable pointer is not memoized" {
+    comptime {
+        var x: i32 = 1;
+        const ptr = &x;
+        increment(ptr);
+        increment(ptr);
+        assert(x == 3);
+    }
+}
+
+fn increment(value: &i32) void {
+    *value += 1;
+}