Commit 4a55fc6c53

Andrew Kelley <andrew@ziglang.org>
2023-07-09 23:41:36
InternPool: avoid false negatives for functions with inferred error sets
There is one case where function types may be unequal but we still want to find the same function body instance in InternPool. In the case of the functions having an inferred error set, the key used to find an existing function body will necessarily have a unique inferred error set type, because it refers to the function body InternPool Index. To make this case work we omit the inferred error set from the equality and hashing functions.
1 parent f3dc53f
Changed files (1)
src/InternPool.zig
@@ -537,6 +537,32 @@ pub const Key = union(enum) {
             assert(i < self.param_types.len);
             return @as(u1, @truncate(self.noalias_bits >> i)) != 0;
         }
+
+        /// Field-by-field equality of two function types. `ip` is needed to
+        /// resolve each side's `param_types` into a slice of `Index`.
+        /// Compared: parameter types, return type, comptime and noalias bits,
+        /// alignment, calling convention, and the var-args, generic and
+        /// noinline flags.
+        pub fn eql(a: FuncType, b: FuncType, ip: *const InternPool) bool {
+            const a_params = a.param_types.get(ip);
+            const b_params = b.param_types.get(ip);
+            if (!std.mem.eql(Index, a_params, b_params)) return false;
+            if (a.return_type != b.return_type) return false;
+            if (a.comptime_bits != b.comptime_bits) return false;
+            if (a.noalias_bits != b.noalias_bits) return false;
+            if (a.alignment != b.alignment) return false;
+            if (a.cc != b.cc) return false;
+            if (a.is_var_args != b.is_var_args) return false;
+            if (a.is_generic != b.is_generic) return false;
+            return a.is_noinline == b.is_noinline;
+        }
+
+        /// Feeds this function type into `hasher`: first every parameter type
+        /// (resolved through `ip`), then the scalar fields in the fixed order
+        /// listed below. The order is part of the resulting hash value and
+        /// must not be changed.
+        pub fn hash(self: FuncType, hasher: *Hash, ip: *const InternPool) void {
+            const params = self.param_types.get(ip);
+            for (params) |param_ty| std.hash.autoHash(hasher, param_ty);
+
+            inline for (.{
+                "return_type",
+                "comptime_bits",
+                "noalias_bits",
+                "alignment",
+                "cc",
+                "is_var_args",
+                "is_generic",
+                "is_noinline",
+            }) |field_name| {
+                std.hash.autoHash(hasher, @field(self, field_name));
+            }
+        }
     };
 
     pub const Variable = struct {
@@ -572,6 +598,7 @@ pub const Key = union(enum) {
         zir_body_inst_extra_index: u32,
         /// Index into extra array of the resolved inferred error set for this function.
         /// Used for mutating that data.
+        /// 0 when the function does not have an inferred error set.
         resolved_error_set_extra_index: u32,
         /// When a generic function is instantiated, branch_quota is inherited from the
         /// active Sema context. Importantly, this value is also updated when an existing
@@ -942,17 +969,7 @@ pub const Key = union(enum) {
 
             .func_type => |func_type| {
                 var hasher = Hash.init(seed);
-                for (func_type.param_types.get(ip)) |param_type| {
-                    std.hash.autoHash(&hasher, param_type);
-                }
-                std.hash.autoHash(&hasher, func_type.return_type);
-                std.hash.autoHash(&hasher, func_type.comptime_bits);
-                std.hash.autoHash(&hasher, func_type.noalias_bits);
-                std.hash.autoHash(&hasher, func_type.alignment);
-                std.hash.autoHash(&hasher, func_type.cc);
-                std.hash.autoHash(&hasher, func_type.is_var_args);
-                std.hash.autoHash(&hasher, func_type.is_generic);
-                std.hash.autoHash(&hasher, func_type.is_noinline);
+                func_type.hash(&hasher, ip);
                 return hasher.final();
             },
 
@@ -964,13 +981,24 @@ pub const Key = union(enum) {
             },
 
             .func => |func| {
-                if (func.generic_owner == .none)
+                // In the case of a function with an inferred error set, we
+                // must not include the inferred error set type in the hash,
+                // otherwise we would get false negatives for interning generic
+                // function instances which have inferred error sets.
+
+                if (func.generic_owner == .none and func.resolved_error_set_extra_index == 0)
                     return Hash.hash(seed, asBytes(&func.owner_decl) ++ asBytes(&func.ty));
 
                 var hasher = Hash.init(seed);
                 std.hash.autoHash(&hasher, func.generic_owner);
                 for (func.comptime_args.get(ip)) |arg| std.hash.autoHash(&hasher, arg);
-                std.hash.autoHash(&hasher, func.ty);
+                if (func.resolved_error_set_extra_index == 0) {
+                    std.hash.autoHash(&hasher, func.ty);
+                } else {
+                    var ty_info = ip.indexToFuncType(func.ty).?;
+                    ty_info.return_type = ip.errorUnionPayload(ty_info.return_type);
+                    ty_info.hash(&hasher, ip);
+                }
                 return hasher.final();
             },
 
@@ -1079,13 +1107,37 @@ pub const Key = union(enum) {
                 if (a_info.generic_owner != b_info.generic_owner)
                     return false;
 
-                if (a_info.ty != b_info.ty)
-                    return false;
+                if (a_info.generic_owner == .none) {
+                    if (a_info.owner_decl != b_info.owner_decl)
+                        return false;
+                } else {
+                    if (!std.mem.eql(
+                        Index,
+                        a_info.comptime_args.get(ip),
+                        b_info.comptime_args.get(ip),
+                    )) return false;
+                }
 
-                if (a_info.generic_owner == .none)
-                    return a_info.owner_decl == b_info.owner_decl;
+                if (a_info.ty == b_info.ty)
+                    return true;
 
-                return std.mem.eql(Index, a_info.comptime_args.get(ip), b_info.comptime_args.get(ip));
+                // There is one case where the types may be unequal but we
+                // still want to find the same function body instance. In the
+                // case of the functions having an inferred error set, the key
+                // used to find an existing function body will necessarily have
+                // a unique inferred error set type, because it refers to the
+                // function body InternPool Index. To make this case work we
+                // omit the inferred error set from the equality check.
+                if (a_info.resolved_error_set_extra_index == 0 or
+                    b_info.resolved_error_set_extra_index == 0)
+                {
+                    return false;
+                }
+                // Replace each side's return type with the payload of its *own*
+                // error union, so only the payload types are compared.
+                var a_ty_info = ip.indexToFuncType(a_info.ty).?;
+                a_ty_info.return_type = ip.errorUnionPayload(a_ty_info.return_type);
+                var b_ty_info = ip.indexToFuncType(b_info.ty).?;
+                b_ty_info.return_type = ip.errorUnionPayload(b_ty_info.return_type);
+                return a_ty_info.eql(b_ty_info, ip);
             },
 
             .ptr => |a_info| {
@@ -1246,16 +1298,7 @@ pub const Key = union(enum) {
 
             .func_type => |a_info| {
                 const b_info = b.func_type;
-
-                return std.mem.eql(Index, a_info.param_types.get(ip), b_info.param_types.get(ip)) and
-                    a_info.return_type == b_info.return_type and
-                    a_info.comptime_bits == b_info.comptime_bits and
-                    a_info.noalias_bits == b_info.noalias_bits and
-                    a_info.alignment == b_info.alignment and
-                    a_info.cc == b_info.cc and
-                    a_info.is_var_args == b_info.is_var_args and
-                    a_info.is_generic == b_info.is_generic and
-                    a_info.is_noinline == b_info.is_noinline;
+                return Key.FuncType.eql(a_info, b_info, ip);
             },
 
             .memoized_call => |a_info| {
@@ -2156,7 +2199,6 @@ pub const Tag = enum(u8) {
     const Error = Key.Error;
     const EnumTag = Key.EnumTag;
     const ExternFunc = Key.ExternFunc;
-    const Func = Key.Func;
     const Union = Key.Union;
     const TypePointer = Key.PtrType;
 
@@ -3137,7 +3179,7 @@ pub fn indexToKey(ip: *const InternPool, index: Index) Key {
         },
         .extern_func => .{ .extern_func = ip.extraData(Tag.ExternFunc, data) },
         .func_instance => .{ .func = ip.indexToKeyFuncInstance(data) },
-        .func_decl => .{ .func = ip.indexToKeyFuncDecl(data) },
+        .func_decl => .{ .func = ip.extraIndexToFuncDecl(data) },
         .only_possible_value => {
             const ty = @as(Index, @enumFromInt(data));
             const ty_item = ip.items.get(@intFromEnum(ty));
@@ -3286,9 +3328,9 @@ fn extraFuncType(ip: *const InternPool, extra_index: u32) Key.FuncType {
     };
 }
 
-fn indexToKeyFuncDecl(ip: *const InternPool, data: u32) Key.Func {
+fn extraIndexToFuncDecl(ip: *const InternPool, extra_index: u32) Key.Func {
     // Unimplemented stub: both parameters are discarded and any call panics with "TODO".
     _ = ip;
-    _ = data;
+    _ = extra_index;
     @panic("TODO");
 }
 
@@ -4357,38 +4399,41 @@ pub fn getFuncDecl(ip: *InternPool, gpa: Allocator, key: GetFuncDeclKey) Allocat
 
     try ip.extra.ensureUnusedCapacity(gpa, @typeInfo(Tag.FuncDecl).Struct.fields.len);
     try ip.items.ensureUnusedCapacity(gpa, 1);
+    try ip.map.ensureUnusedCapacity(gpa, 1);
 
-    ip.items.appendAssumeCapacity(.{
-        .tag = .func_decl,
-        .data = ip.addExtraAssumeCapacity(Tag.FuncDecl{
-            .analysis = .{
-                .state = if (key.cc == .Inline) .inline_only else .none,
-                .is_cold = false,
-                .is_noinline = key.is_noinline,
-                .calls_or_awaits_errorable_fn = false,
-                .stack_alignment = .none,
-                .inferred_error_set = false,
-            },
-            .owner_decl = key.owner_decl,
-            .ty = key.ty,
-            .zir_body_inst = key.zir_body_inst,
-            .lbrace_line = key.lbrace_line,
-            .rbrace_line = key.rbrace_line,
-            .lbrace_column = key.lbrace_column,
-            .rbrace_column = key.rbrace_column,
-        }),
+    const func_decl_extra_index = ip.addExtraAssumeCapacity(Tag.FuncDecl{
+        .analysis = .{
+            .state = if (key.cc == .Inline) .inline_only else .none,
+            .is_cold = false,
+            .is_noinline = key.is_noinline,
+            .calls_or_awaits_errorable_fn = false,
+            .stack_alignment = .none,
+            .inferred_error_set = false,
+        },
+        .owner_decl = key.owner_decl,
+        .ty = key.ty,
+        .zir_body_inst = key.zir_body_inst,
+        .lbrace_line = key.lbrace_line,
+        .rbrace_line = key.rbrace_line,
+        .lbrace_column = key.lbrace_column,
+        .rbrace_column = key.rbrace_column,
     });
 
     const adapter: KeyAdapter = .{ .intern_pool = ip };
-    const gop = try ip.map.getOrPutAdapted(gpa, Key{
-        .func = indexToKeyFuncDecl(ip, @intCast(ip.items.len - 1)),
+    const gop = ip.map.getOrPutAssumeCapacityAdapted(Key{
+        .func = extraIndexToFuncDecl(ip, func_decl_extra_index),
     }, adapter);
-    if (!gop.found_existing) return @enumFromInt(ip.items.len - 1);
 
-    // An existing function type was found; undo the additions to our two arrays.
-    ip.items.len -= 1;
-    ip.extra.items.len = prev_extra_len;
-    return @enumFromInt(gop.index);
+    if (gop.found_existing) {
+        ip.extra.items.len = prev_extra_len;
+        return @enumFromInt(gop.index);
+    }
+
+    ip.items.appendAssumeCapacity(.{
+        .tag = .func_decl,
+        .data = func_decl_extra_index,
+    });
+    return @enumFromInt(ip.items.len - 1);
 }
 
 pub const GetFuncDeclIesKey = struct {
@@ -4434,25 +4479,27 @@ pub fn getFuncDeclIes(ip: *InternPool, gpa: Allocator, key: GetFuncDeclIesKey) A
         params_len);
     try ip.items.ensureUnusedCapacity(gpa, 4);
 
+    const func_decl_extra_index = ip.addExtraAssumeCapacity(Tag.FuncDecl{
+        .analysis = .{
+            .state = if (key.cc == .Inline) .inline_only else .none,
+            .is_cold = false,
+            .is_noinline = key.is_noinline,
+            .calls_or_awaits_errorable_fn = false,
+            .stack_alignment = .none,
+            .inferred_error_set = true,
+        },
+        .owner_decl = key.owner_decl,
+        .ty = @enumFromInt(ip.items.len + 1),
+        .zir_body_inst = key.zir_body_inst,
+        .lbrace_line = key.lbrace_line,
+        .rbrace_line = key.rbrace_line,
+        .lbrace_column = key.lbrace_column,
+        .rbrace_column = key.rbrace_column,
+    });
+
     ip.items.appendAssumeCapacity(.{
         .tag = .func_decl,
-        .data = ip.addExtraAssumeCapacity(Tag.FuncDecl{
-            .analysis = .{
-                .state = if (key.cc == .Inline) .inline_only else .none,
-                .is_cold = false,
-                .is_noinline = key.is_noinline,
-                .calls_or_awaits_errorable_fn = false,
-                .stack_alignment = .none,
-                .inferred_error_set = true,
-            },
-            .owner_decl = key.owner_decl,
-            .ty = @enumFromInt(ip.items.len + 1),
-            .zir_body_inst = key.zir_body_inst,
-            .lbrace_line = key.lbrace_line,
-            .rbrace_line = key.rbrace_line,
-            .lbrace_column = key.lbrace_column,
-            .rbrace_column = key.rbrace_column,
-        }),
+        .data = func_decl_extra_index,
     });
     ip.extra.appendAssumeCapacity(@intFromEnum(Index.none));
 
@@ -4497,7 +4544,7 @@ pub fn getFuncDeclIes(ip: *InternPool, gpa: Allocator, key: GetFuncDeclIesKey) A
 
     const adapter: KeyAdapter = .{ .intern_pool = ip };
     const gop = ip.map.getOrPutAssumeCapacityAdapted(Key{
-        .func = indexToKeyFuncDecl(ip, @intCast(ip.items.len - 4)),
+        .func = extraIndexToFuncDecl(ip, func_decl_extra_index),
     }, adapter);
     if (!gop.found_existing) {
         assert(!ip.map.getOrPutAssumeCapacityAdapted(Key{ .error_union_type = .{
@@ -5570,6 +5617,10 @@ pub fn errorUnionSet(ip: *const InternPool, ty: Index) Index {
     return ip.indexToKey(ty).error_union_type.error_set_type;
 }
 
+/// Returns the payload type of the error union type `ty`.
+/// The `error_union_type` field of the key for `ty` is accessed
+/// unconditionally, so `ty` must refer to an error union type.
+pub fn errorUnionPayload(ip: *const InternPool, ty: Index) Index {
+    const key = ip.indexToKey(ty);
+    return key.error_union_type.payload_type;
+}
+
 /// This is only legal because the initializer is not part of the hash.
 pub fn mutateVarInit(ip: *InternPool, index: Index, init_index: Index) void {
     const item = ip.items.get(@intFromEnum(index));
@@ -5738,7 +5789,7 @@ fn dumpStatsFallible(ip: *const InternPool, arena: Allocator) anyerror!void {
             .float_comptime_float => @sizeOf(Float128),
             .variable => @sizeOf(Tag.Variable) + @sizeOf(Module.Decl),
             .extern_func => @sizeOf(Tag.ExternFunc) + @sizeOf(Module.Decl),
-            .func_decl => @sizeOf(Tag.Func) + @sizeOf(Module.Decl),
+            .func_decl => @sizeOf(Tag.FuncDecl) + @sizeOf(Module.Decl),
             .func_instance => b: {
                 const info = ip.extraData(Tag.FuncInstance, data);
                 const ty = ip.typeOf(info.generic_owner);