Commit da24ea7f36

Jacob Young <jacobly0@users.noreply.github.com>
2023-06-02 10:24:25
Sema: rewrite `monomorphed_funcs` usage
In an effort to delete `Value.hashUncoerced`, generic instantiation has been redesigned. Instead of just storing instantiations in `monomorphed_funcs`, partially instantiated generic argument types are also cached. This isn't quite the single `getOrPut` that it used to be, but one `get` per generic argument plus one `get` for the instantiation, with an equal number of `put`s per unique instantiation, isn't bad.
1 parent 04e66e6
Changed files (3)
src/Module.zig
@@ -99,6 +99,7 @@ tmp_hack_arena: std.heap.ArenaAllocator,
 /// This is currently only used for string literals.
 memoized_decls: std.AutoHashMapUnmanaged(InternPool.Index, Decl.Index) = .{},
 
+monomorphed_func_keys: std.ArrayListUnmanaged(InternPool.Index) = .{},
 /// The set of all the generic function instantiations. This is used so that when a generic
 /// function is called twice with the same comptime parameter arguments, both calls dispatch
 /// to the same function.
@@ -202,24 +203,40 @@ pub const CImportError = struct {
     }
 };
 
-const MonomorphedFuncsSet = std.HashMapUnmanaged(
-    Fn.Index,
-    void,
+pub const MonomorphedFuncKey = struct { func: Fn.Index, args_index: u32, args_len: u32 };
+
+pub const MonomorphedFuncAdaptedKey = struct { func: Fn.Index, args: []const InternPool.Index };
+
+pub const MonomorphedFuncsSet = std.HashMapUnmanaged(
+    MonomorphedFuncKey,
+    InternPool.Index,
     MonomorphedFuncsContext,
     std.hash_map.default_max_load_percentage,
 );
 
-const MonomorphedFuncsContext = struct {
+pub const MonomorphedFuncsContext = struct {
+    mod: *Module,
+
+    pub fn eql(_: @This(), a: MonomorphedFuncKey, b: MonomorphedFuncKey) bool {
+        return std.meta.eql(a, b);
+    }
+
+    pub fn hash(ctx: @This(), key: MonomorphedFuncKey) u64 {
+        const key_args = ctx.mod.monomorphed_func_keys.items[key.args_index..][0..key.args_len];
+        return std.hash.Wyhash.hash(@enumToInt(key.func), std.mem.sliceAsBytes(key_args));
+    }
+};
+
+pub const MonomorphedFuncsAdaptedContext = struct {
     mod: *Module,
 
-    pub fn eql(ctx: @This(), a: Fn.Index, b: Fn.Index) bool {
-        _ = ctx;
-        return a == b;
+    pub fn eql(ctx: @This(), adapted_key: MonomorphedFuncAdaptedKey, other_key: MonomorphedFuncKey) bool {
+        const other_key_args = ctx.mod.monomorphed_func_keys.items[other_key.args_index..][0..other_key.args_len];
+        return adapted_key.func == other_key.func and std.mem.eql(InternPool.Index, adapted_key.args, other_key_args);
     }
 
-    /// Must match `Sema.GenericCallAdapter.hash`.
-    pub fn hash(ctx: @This(), key: Fn.Index) u64 {
-        return ctx.mod.funcPtr(key).hash;
+    pub fn hash(_: @This(), adapted_key: MonomorphedFuncAdaptedKey) u64 {
+        return std.hash.Wyhash.hash(@enumToInt(adapted_key.func), std.mem.sliceAsBytes(adapted_key.args));
     }
 };
 
@@ -571,9 +588,6 @@ pub const Decl = struct {
     pub fn clearValues(decl: *Decl, mod: *Module) void {
         if (decl.getOwnedFunctionIndex(mod).unwrap()) |func| {
             _ = mod.align_stack_fns.remove(func);
-            if (mod.funcPtr(func).comptime_args != null) {
-                _ = mod.monomorphed_funcs.removeContext(func, .{ .mod = mod });
-            }
             mod.destroyFunc(func);
         }
     }
src/Sema.zig
@@ -6679,78 +6679,6 @@ fn callBuiltin(
     _ = try sema.analyzeCall(block, builtin_fn, func_ty, sema.src, sema.src, modifier, false, args, null, null);
 }
 
-const GenericCallAdapter = struct {
-    generic_fn: *Module.Fn,
-    precomputed_hash: u64,
-    func_ty_info: InternPool.Key.FuncType,
-    args: []const Arg,
-    module: *Module,
-
-    const Arg = struct {
-        ty: Type,
-        val: Value,
-        is_anytype: bool,
-    };
-
-    pub fn eql(ctx: @This(), adapted_key: void, other_key: Module.Fn.Index) bool {
-        _ = adapted_key;
-        const other_func = ctx.module.funcPtr(other_key);
-
-        // Checking for equality may happen on an item that has been inserted
-        // into the map but is not yet fully initialized. In such case, the
-        // two initialized fields are `hash` and `generic_owner_decl`.
-        if (ctx.generic_fn.owner_decl != other_func.generic_owner_decl.unwrap().?) return false;
-
-        const other_comptime_args = other_func.comptime_args.?;
-        for (other_comptime_args[0..ctx.func_ty_info.param_types.len], 0..) |other_arg, i| {
-            const this_arg = ctx.args[i];
-            const this_is_comptime = !this_arg.val.isGenericPoison();
-            const other_is_comptime = !other_arg.val.isGenericPoison();
-            const this_is_anytype = this_arg.is_anytype;
-            const other_is_anytype = other_func.isAnytypeParam(ctx.module, @intCast(u32, i));
-
-            if (other_is_anytype != this_is_anytype) return false;
-            if (other_is_comptime != this_is_comptime) return false;
-
-            if (this_is_anytype) {
-                // Both are anytype parameters.
-                if (!this_arg.ty.eql(other_arg.ty, ctx.module)) {
-                    return false;
-                }
-                if (this_is_comptime) {
-                    // Both are comptime and anytype parameters with matching types.
-                    if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.module)) {
-                        return false;
-                    }
-                }
-            } else if (this_is_comptime) {
-                // Both are comptime parameters but not anytype parameters.
-                // We assert no error is possible here because any lazy values must be resolved
-                // before inserting into the generic function hash map.
-                const is_eql = Value.eqlAdvanced(
-                    this_arg.val,
-                    this_arg.ty,
-                    other_arg.val,
-                    other_arg.ty,
-                    ctx.module,
-                    null,
-                ) catch unreachable;
-                if (!is_eql) {
-                    return false;
-                }
-            }
-        }
-        return true;
-    }
-
-    /// The implementation of the hash is in semantic analysis of function calls, so
-    /// that any errors when computing the hash can be properly reported.
-    pub fn hash(ctx: @This(), adapted_key: void) u64 {
-        _ = adapted_key;
-        return ctx.precomputed_hash;
-    }
-};
-
 fn analyzeCall(
     sema: *Sema,
     block: *Block,
@@ -7480,11 +7408,12 @@ fn instantiateGenericCall(
     const ip = &mod.intern_pool;
 
     const func_val = try sema.resolveConstValue(block, func_src, func, "generic function being called must be comptime-known");
-    const module_fn = mod.funcPtr(switch (ip.indexToKey(func_val.toIntern())) {
+    const module_fn_index = switch (ip.indexToKey(func_val.toIntern())) {
         .func => |function| function.index,
         .ptr => |ptr| mod.declPtr(ptr.addr.decl).val.getFunctionIndex(mod).unwrap().?,
         else => unreachable,
-    });
+    };
+    const module_fn = mod.funcPtr(module_fn_index);
     // Check the Module's generic function map with an adapted context, so that we
     // can match against `uncasted_args` rather than doing the work below to create a
     // generic Scope only to junk it if it matches an existing instantiation.
@@ -7495,32 +7424,24 @@ fn instantiateGenericCall(
     const fn_info = fn_zir.getFnInfo(module_fn.zir_body_inst);
     const zir_tags = fn_zir.instructions.items(.tag);
 
-    // This hash must match `Module.MonomorphedFuncsContext.hash`.
-    // For parameters explicitly marked comptime and simple parameter type expressions,
-    // we know whether a parameter is elided from a monomorphed function, and can
-    // use it in the hash here. However, for parameter type expressions that are not
-    // explicitly marked comptime and rely on previous parameter comptime values, we
-    // don't find out until after generating a monomorphed function whether the parameter
-    // type ended up being a "must-be-comptime-known" type.
-    var hasher = std.hash.Wyhash.init(0);
-    std.hash.autoHash(&hasher, module_fn.owner_decl);
-
-    const generic_args = try sema.arena.alloc(GenericCallAdapter.Arg, func_ty_info.param_types.len);
-    {
-        var i: usize = 0;
+    const generic_args = try sema.arena.alloc(InternPool.Index, func_ty_info.param_types.len);
+    const callee_index = callee: {
+        var arg_i: usize = 0;
+        var generic_arg_i: u32 = 0;
+        var known_unique = false;
         for (fn_info.param_body) |inst| {
             var is_comptime = false;
             var is_anytype = false;
             switch (zir_tags[inst]) {
                 .param => {
-                    is_comptime = func_ty_info.paramIsComptime(@intCast(u5, i));
+                    is_comptime = func_ty_info.paramIsComptime(@intCast(u5, arg_i));
                 },
                 .param_comptime => {
                     is_comptime = true;
                 },
                 .param_anytype => {
                     is_anytype = true;
-                    is_comptime = func_ty_info.paramIsComptime(@intCast(u5, i));
+                    is_comptime = func_ty_info.paramIsComptime(@intCast(u5, arg_i));
                 },
                 .param_anytype_comptime => {
                     is_anytype = true;
@@ -7529,7 +7450,15 @@ fn instantiateGenericCall(
                 else => continue,
             }
 
-            const arg_ty = sema.typeOf(uncasted_args[i]);
+            defer arg_i += 1;
+            if (known_unique) {
+                if (is_comptime or is_anytype) {
+                    generic_arg_i += 1;
+                }
+                continue;
+            }
+
+            const arg_ty = sema.typeOf(uncasted_args[arg_i]);
             if (is_comptime or is_anytype) {
                 // Tuple default values are a part of the type and need to be
                 // resolved to hash the type.
@@ -7537,69 +7466,72 @@ fn instantiateGenericCall(
             }
 
             if (is_comptime) {
-                const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) {
+                const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[arg_i]) catch |err| switch (err) {
                     error.NeededSourceLocation => {
                         const decl = sema.mod.declPtr(block.src_decl);
-                        const arg_src = mod.argSrc(call_src.node_offset.x, decl, i, bound_arg_src);
-                        _ = try sema.analyzeGenericCallArgVal(block, arg_src, uncasted_args[i]);
+                        const arg_src = mod.argSrc(call_src.node_offset.x, decl, arg_i, bound_arg_src);
+                        _ = try sema.analyzeGenericCallArgVal(block, arg_src, uncasted_args[arg_i]);
                         unreachable;
                     },
                     else => |e| return e,
                 };
-                arg_val.hashUncoerced(arg_ty, &hasher, mod);
+
                 if (is_anytype) {
-                    std.hash.autoHash(&hasher, arg_ty.toIntern());
-                    generic_args[i] = .{
-                        .ty = arg_ty,
-                        .val = arg_val,
-                        .is_anytype = true,
-                    };
+                    generic_args[generic_arg_i] = arg_val.toIntern();
                 } else {
-                    generic_args[i] = .{
-                        .ty = arg_ty,
-                        .val = arg_val,
-                        .is_anytype = false,
+                    const final_arg_ty = mod.monomorphed_funcs.getAdapted(
+                        Module.MonomorphedFuncAdaptedKey{
+                            .func = module_fn_index,
+                            .args = generic_args[0..generic_arg_i],
+                        },
+                        Module.MonomorphedFuncsAdaptedContext{ .mod = mod },
+                    ) orelse {
+                        known_unique = true;
+                        generic_arg_i += 1;
+                        continue;
+                    };
+                    const casted_arg = sema.coerce(block, final_arg_ty.toType(), uncasted_args[arg_i], .unneeded) catch |err| switch (err) {
+                        error.NeededSourceLocation => {
+                            const decl = sema.mod.declPtr(block.src_decl);
+                            const arg_src = mod.argSrc(call_src.node_offset.x, decl, arg_i, bound_arg_src);
+                            _ = try sema.coerce(block, final_arg_ty.toType(), uncasted_args[arg_i], arg_src);
+                            unreachable;
+                        },
+                        else => |e| return e,
                     };
+                    const casted_arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, casted_arg) catch |err| switch (err) {
+                        error.NeededSourceLocation => {
+                            const decl = sema.mod.declPtr(block.src_decl);
+                            const arg_src = mod.argSrc(call_src.node_offset.x, decl, arg_i, bound_arg_src);
+                            _ = try sema.analyzeGenericCallArgVal(block, arg_src, casted_arg);
+                            unreachable;
+                        },
+                        else => |e| return e,
+                    };
+                    generic_args[generic_arg_i] = casted_arg_val.toIntern();
                 }
+                generic_arg_i += 1;
             } else if (is_anytype) {
-                std.hash.autoHash(&hasher, arg_ty.toIntern());
-                generic_args[i] = .{
-                    .ty = arg_ty,
-                    .val = Value.generic_poison,
-                    .is_anytype = true,
-                };
-            } else {
-                generic_args[i] = .{
-                    .ty = arg_ty,
-                    .val = Value.generic_poison,
-                    .is_anytype = false,
-                };
+                generic_args[generic_arg_i] = arg_ty.toIntern();
+                generic_arg_i += 1;
             }
-
-            i += 1;
         }
-    }
 
-    const precomputed_hash = hasher.final();
+        if (!known_unique) {
+            if (mod.monomorphed_funcs.getAdapted(
+                Module.MonomorphedFuncAdaptedKey{
+                    .func = module_fn_index,
+                    .args = generic_args[0..generic_arg_i],
+                },
+                Module.MonomorphedFuncsAdaptedContext{ .mod = mod },
+            )) |callee_func| break :callee mod.intern_pool.indexToKey(callee_func).func.index;
+        }
 
-    const adapter: GenericCallAdapter = .{
-        .generic_fn = module_fn,
-        .precomputed_hash = precomputed_hash,
-        .func_ty_info = func_ty_info,
-        .args = generic_args,
-        .module = mod,
-    };
-    const gop = try mod.monomorphed_funcs.getOrPutContextAdapted(gpa, {}, adapter, .{ .mod = mod });
-    const callee_index = if (!gop.found_existing) callee: {
         const new_module_func_index = try mod.createFunc(undefined);
         const new_module_func = mod.funcPtr(new_module_func_index);
 
-        // This ensures that we can operate on the hash map before the Module.Fn
-        // struct is fully initialized.
-        new_module_func.hash = precomputed_hash;
         new_module_func.generic_owner_decl = module_fn.owner_decl.toOptional();
         new_module_func.comptime_args = null;
-        gop.key_ptr.* = new_module_func_index;
 
         try namespace.anon_decls.ensureUnusedCapacity(gpa, 1);
 
@@ -7641,7 +7573,8 @@ fn instantiateGenericCall(
             new_decl,
             new_decl_index,
             uncasted_args,
-            module_fn,
+            generic_arg_i,
+            module_fn_index,
             new_module_func_index,
             namespace_index,
             func_ty_info,
@@ -7657,12 +7590,10 @@ fn instantiateGenericCall(
                 }
                 assert(namespace.anon_decls.orderedRemove(new_decl_index));
                 mod.destroyDecl(new_decl_index);
-                assert(mod.monomorphed_funcs.removeContext(new_module_func_index, .{ .mod = mod }));
                 mod.destroyFunc(new_module_func_index);
                 return err;
             },
             else => {
-                assert(mod.monomorphed_funcs.removeContext(new_module_func_index, .{ .mod = mod }));
                 // TODO look up the compile error that happened here and attach a note to it
                 // pointing here, at the generic instantiation callsite.
                 if (sema.owner_func) |owner_func| {
@@ -7675,9 +7606,8 @@ fn instantiateGenericCall(
         };
 
         break :callee new_func;
-    } else gop.key_ptr.*;
+    };
     const callee = mod.funcPtr(callee_index);
-
     callee.branch_quota = @max(callee.branch_quota, sema.branch_quota);
 
     const callee_inst = try sema.analyzeDeclVal(block, func_src, callee.owner_decl);
@@ -7752,7 +7682,7 @@ fn instantiateGenericCall(
     if (call_tag == .call_always_tail) {
         return sema.handleTailCall(block, call_src, func_ty, result);
     }
-    if (new_fn_info.return_type == .noreturn_type) {
+    if (func_ty.fnReturnType(mod).isNoReturn(mod)) {
         _ = try block.addNoOp(.unreach);
         return Air.Inst.Ref.unreachable_value;
     }
@@ -7766,7 +7696,8 @@ fn resolveGenericInstantiationType(
     new_decl: *Decl,
     new_decl_index: Decl.Index,
     uncasted_args: []const Air.Inst.Ref,
-    module_fn: *Module.Fn,
+    generic_args_len: u32,
+    module_fn_index: Module.Fn.Index,
     new_module_func: Module.Fn.Index,
     namespace: Namespace.Index,
     func_ty_info: InternPool.Key.FuncType,
@@ -7777,6 +7708,7 @@ fn resolveGenericInstantiationType(
     const gpa = sema.gpa;
 
     const zir_tags = fn_zir.instructions.items(.tag);
+    const module_fn = mod.funcPtr(module_fn_index);
     const fn_info = fn_zir.getFnInfo(module_fn.zir_body_inst);
 
     // Re-run the block that creates the function, with the comptime parameters
@@ -7893,9 +7825,15 @@ fn resolveGenericInstantiationType(
     const new_func = new_func_val.getFunctionIndex(mod).unwrap().?;
     assert(new_func == new_module_func);
 
+    const generic_args_index = @intCast(u32, mod.monomorphed_func_keys.items.len);
+    const generic_args = try mod.monomorphed_func_keys.addManyAsSlice(gpa, generic_args_len);
+    var generic_arg_i: u32 = 0;
+    try mod.monomorphed_funcs.ensureUnusedCapacityContext(gpa, generic_args_len + 1, .{ .mod = mod });
+
     arg_i = 0;
     for (fn_info.param_body) |inst| {
         var is_comptime = false;
+        var is_anytype = false;
         switch (zir_tags[inst]) {
             .param => {
                 is_comptime = func_ty_info.paramIsComptime(@intCast(u5, arg_i));
@@ -7904,9 +7842,11 @@ fn resolveGenericInstantiationType(
                 is_comptime = true;
             },
             .param_anytype => {
+                is_anytype = true;
                 is_comptime = func_ty_info.paramIsComptime(@intCast(u5, arg_i));
             },
             .param_anytype_comptime => {
+                is_anytype = true;
                 is_comptime = true;
             },
             else => continue,
@@ -7924,11 +7864,24 @@ fn resolveGenericInstantiationType(
 
         if (is_comptime) {
             const arg_val = (child_sema.resolveMaybeUndefValAllowVariables(arg) catch unreachable).?;
+            if (!is_anytype) {
+                if (mod.monomorphed_funcs.fetchPutAssumeCapacityContext(.{
+                    .func = module_fn_index,
+                    .args_index = generic_args_index,
+                    .args_len = generic_arg_i,
+                }, arg_ty.toIntern(), .{ .mod = mod })) |kv| assert(kv.value == arg_ty.toIntern());
+            }
+            generic_args[generic_arg_i] = arg_val.toIntern();
+            generic_arg_i += 1;
             child_sema.comptime_args[arg_i] = .{
                 .ty = arg_ty,
                 .val = (try arg_val.intern(arg_ty, mod)).toValue(),
             };
         } else {
+            if (is_anytype) {
+                generic_args[generic_arg_i] = arg_ty.toIntern();
+                generic_arg_i += 1;
+            }
             child_sema.comptime_args[arg_i] = .{
                 .ty = arg_ty,
                 .val = Value.generic_poison,
@@ -7963,6 +7916,12 @@ fn resolveGenericInstantiationType(
     new_decl.owns_tv = true;
     new_decl.analysis = .complete;
 
+    mod.monomorphed_funcs.putAssumeCapacityNoClobberContext(.{
+        .func = module_fn_index,
+        .args_index = generic_args_index,
+        .args_len = generic_arg_i,
+    }, new_decl.val.toIntern(), .{ .mod = mod });
+
     // Queue up a `codegen_func` work item for the new Fn. The `comptime_args` field
     // will be populated, ensuring it will have `analyzeBody` called with the ZIR
     // parameters mapped appropriately.
src/value.zig
@@ -1691,77 +1691,6 @@ pub const Value = struct {
         return (try orderAdvanced(a, b, mod, opt_sema)).compare(.eq);
     }
 
-    /// This is a more conservative hash function that produces equal hashes for values
-    /// that can coerce into each other.
-    /// This function is used by hash maps and so treats floating-point NaNs as equal
-    /// to each other, and not equal to other floating-point values.
-    pub fn hashUncoerced(val: Value, ty: Type, hasher: *std.hash.Wyhash, mod: *Module) void {
-        if (val.isUndef(mod)) return;
-        // The value is runtime-known and shouldn't affect the hash.
-        if (val.isRuntimeValue(mod)) return;
-
-        if (val.ip_index != .none) {
-            // The InternPool data structure hashes based on Key to make interned objects
-            // unique. An Index can be treated simply as u32 value for the
-            // purpose of Type/Value hashing and equality.
-            std.hash.autoHash(hasher, val.toIntern());
-            return;
-        }
-
-        switch (ty.zigTypeTag(mod)) {
-            .Opaque => unreachable, // Cannot hash opaque types
-            .Void,
-            .NoReturn,
-            .Undefined,
-            .Null,
-            .Struct, // It sure would be nice to do something clever with structs.
-            => |zig_type_tag| std.hash.autoHash(hasher, zig_type_tag),
-            .Pointer => {
-                assert(ty.isSlice(mod));
-                const slice = val.castTag(.slice).?.data;
-                const ptr_ty = ty.slicePtrFieldType(mod);
-                slice.ptr.hashUncoerced(ptr_ty, hasher, mod);
-            },
-            .Type,
-            .Float,
-            .ComptimeFloat,
-            .Bool,
-            .Int,
-            .ComptimeInt,
-            .Fn,
-            .Optional,
-            .ErrorSet,
-            .ErrorUnion,
-            .Enum,
-            .EnumLiteral,
-            => unreachable, // handled above with the ip_index check
-            .Array, .Vector => {
-                const len = ty.arrayLen(mod);
-                const elem_ty = ty.childType(mod);
-                var index: usize = 0;
-                while (index < len) : (index += 1) {
-                    const elem_val = val.elemValue(mod, index) catch |err| switch (err) {
-                        // Will be solved when arrays and vectors get migrated to the intern pool.
-                        error.OutOfMemory => @panic("OOM"),
-                    };
-                    elem_val.hashUncoerced(elem_ty, hasher, mod);
-                }
-            },
-            .Union => {
-                hasher.update(val.tagName(mod));
-                switch (mod.intern_pool.indexToKey(val.toIntern())) {
-                    .un => |un| {
-                        const active_field_ty = ty.unionFieldType(un.tag.toValue(), mod);
-                        un.val.toValue().hashUncoerced(active_field_ty, hasher, mod);
-                    },
-                    else => std.hash.autoHash(hasher, std.builtin.TypeId.Void),
-                }
-            },
-            .Frame => @panic("TODO implement hashing frame values"),
-            .AnyFrame => @panic("TODO implement hashing anyframe values"),
-        }
-    }
-
     pub fn isComptimeMutablePtr(val: Value, mod: *Module) bool {
         return switch (mod.intern_pool.indexToKey(val.toIntern())) {
             .ptr => |ptr| switch (ptr.addr) {