Commit c8b9364b30

Jacob Young <jacobly0@users.noreply.github.com>
2024-06-16 01:58:29
InternPool: use thread-safe hash map for strings
1 parent cda716e
Changed files (1)
src/InternPool.zig
@@ -46,14 +46,6 @@ namespaces_free_list: std.ArrayListUnmanaged(NamespaceIndex) = .{},
 /// These are not serialized; it is computed upon deserialization.
 maps: std.ArrayListUnmanaged(FieldMap) = .{},
 
-/// Used for finding the index inside `string_bytes`.
-string_table: std.HashMapUnmanaged(
-    u32,
-    void,
-    std.hash_map.StringIndexContext,
-    std.hash_map.default_max_load_percentage,
-) = .{},
-
 /// An index into `tracked_insts` gives a reference to a single ZIR instruction which
 /// persists across incremental updates.
 tracked_insts: std.AutoArrayHashMapUnmanaged(TrackedInst, void) = .{},
@@ -358,22 +350,31 @@ const Local = struct {
     /// node: Garbage.Node,
     /// header: List.Header,
     /// data: [capacity]u32,
-    /// tag: [capacity]Tag,
+    /// tag: [header.capacity]Tag,
     items: List,
 
     /// node: Garbage.Node,
     /// header: List.Header,
-    /// extra: [capacity]u32,
+    /// extra: [header.capacity]u32,
     extra: List,
 
+    /// node: Garbage.Node,
+    /// header: List.Header,
+    /// bytes: [header.capacity]u8,
+    strings: List,
+
     garbage: Garbage,
 
     const List = struct {
         entries: [*]u32,
 
-        const empty: List = .{
-            .entries = @constCast(&[_]u32{ 0, 0 })[Header.fields_len..].ptr,
-        };
+        const empty: List = .{ .entries = @constCast(&(extern struct {
+            header: Header,
+            entries: [0]u32,
+        }{
+            .header = .{ .len = 0, .capacity = 0 },
+            .entries = .{},
+        }).entries) };
 
         fn acquire(list: *const List) List {
             return .{ .entries = @atomicLoad([*]u32, &list.entries, .acquire) };
@@ -402,63 +403,75 @@ const Local = struct {
 };
 
 const Shard = struct {
-    aligned: void align(std.atomic.cache_line) = {},
-
-    mutate_mutex: std.Thread.Mutex.Recursive,
-
-    /// node: Local.Garbage.Node,
-    /// header: Map.Header,
-    /// entries: [capacity]Map.Entry,
-    map: Map,
+    shared: struct {
+        map: Map(Index),
+        string_map: Map(OptionalNullTerminatedString),
+    } align(std.atomic.cache_line),
+    mutate: struct {
+        // TODO: measure cost of sharing unrelated mutate state
+        map: Mutate align(std.atomic.cache_line),
+        string_map: Mutate align(std.atomic.cache_line),
+    },
 
-    const Map = struct {
-        entries: [*]u32,
+    const Mutate = struct {
+        mutex: std.Thread.Mutex.Recursive,
+        len: u32,
 
-        const empty: Map = .{
-            .entries = @constCast(&[_]u32{ 0, 1, @intFromEnum(Index.none), 0 })[Header.fields_len..].ptr,
+        const empty: Mutate = .{
+            .mutex = std.Thread.Mutex.Recursive.init,
+            .len = 0,
         };
+    };
 
-        fn acquire(map: *const Map) Map {
-            return .{ .entries = @atomicLoad([*]u32, &map.entries, .acquire) };
-        }
-        fn release(map: *Map, new_map: Map) void {
-            @atomicStore([*]u32, &map.entries, new_map.entries, .release);
-        }
-
-        const Header = extern struct {
-            len: u32,
-            capacity: u32,
+    fn Map(comptime Value: type) type {
+        comptime assert(@typeInfo(Value).Enum.tag_type == u32);
+        _ = @as(Value, .none); // expected .none key
+        return struct {
+            /// node: Local.Garbage.Node,
+            /// header: Header,
+            /// entries: [header.capacity]Entry,
+            entries: [*]Entry,
+
+            const empty: @This() = .{ .entries = @constCast(&(extern struct {
+                header: Header,
+                entries: [1]Entry,
+            }{
+                .header = .{ .capacity = 1 },
+                .entries = .{.{ .value = .none, .hash = undefined }},
+            }).entries) };
+
+            fn acquire(map: *const @This()) @This() {
+                return .{ .entries = @atomicLoad([*]Entry, &map.entries, .acquire) };
+            }
+            fn release(map: *@This(), new_map: @This()) void {
+                @atomicStore([*]Entry, &map.entries, new_map.entries, .release);
+            }
 
-            const fields_len: u32 = @typeInfo(Header).Struct.fields.len;
+            const Header = extern struct {
+                capacity: u32,
 
-            fn mask(head: *const Header) u32 {
-                assert(std.math.isPowerOfTwo(head.capacity));
-                assert(std.math.isPowerOfTwo(Entry.fields_len));
-                return (head.capacity - 1) * Entry.fields_len;
+                fn mask(head: *const Header) u32 {
+                    assert(std.math.isPowerOfTwo(head.capacity));
+                    return head.capacity - 1;
+                }
+            };
+            fn header(map: @This()) *Header {
+                return &(@as([*]Header, @ptrCast(map.entries)) - 1)[0];
             }
-        };
-        fn header(map: Map) *Header {
-            return @ptrCast(map.entries - Header.fields_len);
-        }
 
-        const Entry = extern struct {
-            index: Index,
-            hash: u32,
+            const Entry = extern struct {
+                value: Value,
+                hash: u32,
 
-            const fields_len: u32 = @typeInfo(Entry).Struct.fields.len;
-
-            fn acquire(entry: *const Entry) Index {
-                return @atomicLoad(Index, &entry.index, .acquire);
-            }
-            fn release(entry: *Entry, index: Index) void {
-                @atomicStore(Index, &entry.index, index, .release);
-            }
+                fn acquire(entry: *const Entry) Value {
+                    return @atomicLoad(Value, &entry.value, .acquire);
+                }
+                fn release(entry: *Entry, value: Value) void {
+                    @atomicStore(Value, &entry.value, value, .release);
+                }
+            };
         };
-        fn at(map: Map, index: usize) *Entry {
-            assert(index % Entry.fields_len == 0);
-            return @ptrCast(&map.entries[index]);
-        }
-    };
+    }
 };
 
 const FieldMap = std.ArrayHashMapUnmanaged(void, void, std.array_hash_map.AutoContext(void), false);
@@ -618,9 +631,13 @@ pub const NullTerminatedString = enum(u32) {
         return @enumFromInt(@intFromEnum(self));
     }
 
+    fn toOverlongSlice(string: NullTerminatedString, ip: *const InternPool) []const u8 {
+        return ip.string_bytes.items[@intFromEnum(string)..];
+    }
+
     pub fn toSlice(string: NullTerminatedString, ip: *const InternPool) [:0]const u8 {
-        const slice = ip.string_bytes.items[@intFromEnum(string)..];
-        return slice[0..std.mem.indexOfScalar(u8, slice, 0).? :0];
+        const overlong_slice = string.toOverlongSlice(ip);
+        return overlong_slice[0..std.mem.indexOfScalar(u8, overlong_slice, 0).? :0];
     }
 
     pub fn length(string: NullTerminatedString, ip: *const InternPool) u32 {
@@ -628,7 +645,10 @@ pub const NullTerminatedString = enum(u32) {
     }
 
     pub fn eqlSlice(string: NullTerminatedString, slice: []const u8, ip: *const InternPool) bool {
-        return std.mem.eql(u8, string.toSlice(ip), slice);
+        const overlong_slice = string.toOverlongSlice(ip);
+        return overlong_slice.len > slice.len and
+            std.mem.eql(u8, overlong_slice[0..slice.len], slice) and
+            overlong_slice[slice.len] == 0;
     }
 
     const Adapter = struct {
@@ -4639,14 +4659,21 @@ pub fn init(ip: *InternPool, gpa: Allocator, total_threads: usize) !void {
     @memset(ip.local, .{
         .items = Local.List.empty,
         .extra = Local.List.empty,
+        .strings = Local.List.empty,
         .garbage = .{},
     });
 
     ip.shard_shift = @intCast(std.math.log2_int_ceil(usize, total_threads));
     ip.shards = try gpa.alloc(Shard, @as(usize, 1) << ip.shard_shift);
     @memset(ip.shards, .{
-        .mutate_mutex = std.Thread.Mutex.Recursive.init,
-        .map = Shard.Map.empty,
+        .shared = .{
+            .map = Shard.Map(Index).empty,
+            .string_map = Shard.Map(OptionalNullTerminatedString).empty,
+        },
+        .mutate = .{
+            .map = Shard.Mutate.empty,
+            .string_map = Shard.Mutate.empty,
+        },
     });
 
     // Reserve string index 0 for an empty string.
@@ -4697,8 +4724,6 @@ pub fn deinit(ip: *InternPool, gpa: Allocator) void {
     for (ip.maps.items) |*map| map.deinit(gpa);
     ip.maps.deinit(gpa);
 
-    ip.string_table.deinit(gpa);
-
     ip.tracked_insts.deinit(gpa);
 
     ip.src_hash_deps.deinit(gpa);
@@ -5363,9 +5388,9 @@ const GetOrPutKey = union(enum) {
         switch (gop.*) {
             .existing => unreachable,
             .new => |info| {
-                info.shard.map.at(info.map_index).release(index);
-                info.shard.map.header().len += 1;
-                info.shard.mutate_mutex.unlock();
+                info.shard.shared.map.entries[info.map_index].release(index);
+                info.shard.mutate.map.len += 1;
+                info.shard.mutate.map.mutex.unlock();
             },
         }
         gop.* = .{ .existing = index };
@@ -5380,7 +5405,7 @@ const GetOrPutKey = union(enum) {
     fn deinit(gop: *GetOrPutKey) void {
         switch (gop.*) {
             .existing => {},
-            .new => |info| info.shard.mutate_mutex.unlock(),
+            .new => |info| info.shard.mutate.map.mutex.unlock(),
         }
         gop.* = undefined;
     }
@@ -5394,70 +5419,69 @@ fn getOrPutKey(
     const full_hash = key.hash64(ip);
     const hash: u32 = @truncate(full_hash >> 32);
     const shard = &ip.shards[@intCast(full_hash & (ip.shards.len - 1))];
-    var map = shard.map.acquire();
+    var map = shard.shared.map.acquire();
+    const Map = @TypeOf(map);
     var map_mask = map.header().mask();
     var map_index = hash;
-    while (true) : (map_index += Shard.Map.Entry.fields_len) {
+    while (true) : (map_index += 1) {
         map_index &= map_mask;
-        const entry = map.at(map_index);
+        const entry = &map.entries[map_index];
         const index = entry.acquire();
         if (index == .none) break;
-        if (entry.hash == hash and ip.indexToKey(index).eql(key, ip))
-            return .{ .existing = index };
+        if (entry.hash != hash) continue;
+        if (ip.indexToKey(index).eql(key, ip)) return .{ .existing = index };
     }
-    shard.mutate_mutex.lock();
-    errdefer shard.mutate_mutex.unlock();
-    if (map.entries != shard.map.entries) {
-        map = shard.map;
+    shard.mutate.map.mutex.lock();
+    errdefer shard.mutate.map.mutex.unlock();
+    if (map.entries != shard.shared.map.entries) {
+        map = shard.shared.map;
         map_mask = map.header().mask();
         map_index = hash;
     }
-    while (true) : (map_index += Shard.Map.Entry.fields_len) {
+    while (true) : (map_index += 1) {
         map_index &= map_mask;
-        const entry = map.at(map_index);
-        const index = entry.index;
+        const entry = &map.entries[map_index];
+        const index = entry.value;
         if (index == .none) break;
-        if (entry.hash == hash and ip.indexToKey(index).eql(key, ip)) {
-            defer shard.mutate_mutex.unlock();
+        if (entry.hash != hash) continue;
+        if (ip.indexToKey(index).eql(key, ip)) {
+            defer shard.mutate.map.mutex.unlock();
             return .{ .existing = index };
         }
     }
     const map_header = map.header().*;
-    if (map_header.len >= map_header.capacity * 3 / 5) {
+    if (shard.mutate.map.len >= map_header.capacity * 3 / 5) {
         const new_map_capacity = map_header.capacity * 2;
         const new_map_buf = try gpa.alignedAlloc(
             u8,
             Local.garbage_align,
-            @sizeOf(Local.Garbage.Node) + (Shard.Map.Header.fields_len +
-                new_map_capacity * Shard.Map.Entry.fields_len) * @sizeOf(u32),
+            @sizeOf(Local.Garbage.Node) + @sizeOf(Map.Header) +
+                new_map_capacity * @sizeOf(Map.Entry),
         );
         const new_node: *Local.Garbage.Node = @ptrCast(new_map_buf.ptr);
         new_node.* = .{ .data = .{ .buf_len = new_map_buf.len } };
         ip.local[@intFromEnum(tid)].garbage.prepend(new_node);
         const new_map_entries = std.mem.bytesAsSlice(
-            u32,
-            new_map_buf[@sizeOf(Local.Garbage.Node)..],
-        )[Shard.Map.Header.fields_len..];
-        const new_map: Shard.Map = .{ .entries = new_map_entries.ptr };
-        new_map.header().* = .{
-            .len = map_header.len,
-            .capacity = new_map_capacity,
-        };
-        @memset(new_map_entries, @intFromEnum(Index.none));
+            Map.Entry,
+            new_map_buf[@sizeOf(Local.Garbage.Node) + @sizeOf(Map.Header) ..],
+        );
+        const new_map: Map = .{ .entries = new_map_entries.ptr };
+        new_map.header().* = .{ .capacity = new_map_capacity };
+        @memset(new_map_entries, .{ .value = .none, .hash = undefined });
         const new_map_mask = new_map.header().mask();
         map_index = 0;
-        while (map_index < map_header.capacity * 2) : (map_index += Shard.Map.Entry.fields_len) {
-            const entry = map.at(map_index);
-            const index = entry.index;
+        while (map_index < map_header.capacity) : (map_index += 1) {
+            const entry = &map.entries[map_index];
+            const index = entry.value;
             if (index == .none) continue;
             const item_hash = entry.hash;
             var new_map_index = item_hash;
-            while (true) : (new_map_index += Shard.Map.Entry.fields_len) {
+            while (true) : (new_map_index += 1) {
                 new_map_index &= new_map_mask;
-                const new_entry = new_map.at(new_map_index);
-                if (new_entry.index != .none) continue;
+                const new_entry = &new_map.entries[new_map_index];
+                if (new_entry.value != .none) continue;
                 new_entry.* = .{
-                    .index = index,
+                    .value = index,
                     .hash = item_hash,
                 };
                 break;
@@ -5465,13 +5489,13 @@ fn getOrPutKey(
         }
         map = new_map;
         map_index = hash;
-        while (true) : (map_index += Shard.Map.Entry.fields_len) {
+        while (true) : (map_index += 1) {
             map_index &= new_map_mask;
-            if (map.at(map_index).index == .none) break;
+            if (map.entries[map_index].value == .none) break;
         }
-        shard.map.release(new_map);
+        shard.shared.map.release(new_map);
     }
-    map.at(map_index).hash = hash;
+    map.entries[map_index].hash = hash;
     return .{ .new = .{ .shard = shard, .map_index = map_index } };
 }
 
@@ -7689,22 +7713,19 @@ pub fn getIfExists(ip: *const InternPool, key: Key) ?Index {
     const full_hash = key.hash64(ip);
     const hash: u32 = @truncate(full_hash >> 32);
     const shard = &ip.shards[@intCast(full_hash & (ip.shards.len - 1))];
-    const map = shard.map.acquire();
+    const map = shard.shared.map.acquire();
     const map_mask = map.header().mask();
     var map_index = hash;
-    while (true) : (map_index += Shard.Map.Entry.fields_len) {
+    while (true) : (map_index += 1) {
         map_index &= map_mask;
-        const entry = map.at(map_index);
+        const entry = &map.entries[map_index];
         const index = entry.acquire();
         if (index == .none) return null;
-        if (entry.hash == hash and ip.indexToKey(index).eql(key, ip)) return index;
+        if (entry.hash != hash) continue;
+        if (ip.indexToKey(index).eql(key, ip)) return index;
     }
 }
 
-pub fn getAssumeExists(ip: *const InternPool, key: Key) Index {
-    return ip.getIfExists(key).?;
-}
-
 fn addStringsToMap(
     ip: *InternPool,
     map_index: MapIndex,
@@ -8618,7 +8639,13 @@ fn dumpStatsFallible(ip: *const InternPool, arena: Allocator) anyerror!void {
             .type_inferred_error_set => 0,
             .type_enum_explicit, .type_enum_nonexhaustive => b: {
                 const info = ip.extraData(EnumExplicit, data);
-                var ints = @typeInfo(EnumExplicit).Struct.fields.len + info.captures_len + info.fields_len;
+                var ints = @typeInfo(EnumExplicit).Struct.fields.len;
+                if (info.zir_index == .none) ints += 1;
+                ints += if (info.captures_len != std.math.maxInt(u32))
+                    info.captures_len
+                else
+                    @typeInfo(PackedU64).Struct.fields.len;
+                ints += info.fields_len;
                 if (info.values_map != .none) ints += info.fields_len;
                 break :b @sizeOf(u32) * ints;
             },
@@ -9084,7 +9111,6 @@ pub fn getOrPutTrailingString(
     len: usize,
     comptime embedded_nulls: EmbeddedNulls,
 ) Allocator.Error!embedded_nulls.StringType() {
-    _ = tid;
     const string_bytes = &ip.string_bytes;
     const str_index: u32 = @intCast(string_bytes.items.len - len);
     if (len > 0 and string_bytes.getLast() == 0) {
@@ -9101,25 +9127,123 @@ pub fn getOrPutTrailingString(
             return @enumFromInt(str_index);
         },
     }
-    const gop = try ip.string_table.getOrPutContextAdapted(gpa, key, std.hash_map.StringIndexAdapter{
-        .bytes = string_bytes,
-    }, std.hash_map.StringIndexContext{
-        .bytes = string_bytes,
-    });
-    if (gop.found_existing) {
+    const maybe_existing_index = try ip.getOrPutStringValue(gpa, tid, key, @enumFromInt(str_index));
+    if (maybe_existing_index.unwrap()) |existing_index| {
         string_bytes.shrinkRetainingCapacity(str_index);
-        return @enumFromInt(gop.key_ptr.*);
+        return @enumFromInt(@intFromEnum(existing_index));
     } else {
-        gop.key_ptr.* = str_index;
         string_bytes.appendAssumeCapacity(0);
         return @enumFromInt(str_index);
     }
 }
 
-pub fn getString(ip: *InternPool, s: []const u8) OptionalNullTerminatedString {
-    return if (ip.string_table.getKeyAdapted(s, std.hash_map.StringIndexAdapter{
-        .bytes = &ip.string_bytes,
-    })) |index| @enumFromInt(index) else .none;
+fn getOrPutStringValue(
+    ip: *InternPool,
+    gpa: Allocator,
+    tid: Zcu.PerThread.Id,
+    key: []const u8,
+    value: NullTerminatedString,
+) Allocator.Error!OptionalNullTerminatedString {
+    const full_hash = Hash.hash(0, key);
+    const hash: u32 = @truncate(full_hash >> 32);
+    const shard = &ip.shards[@intCast(full_hash & (ip.shards.len - 1))];
+    var map = shard.shared.string_map.acquire();
+    const Map = @TypeOf(map);
+    var map_mask = map.header().mask();
+    var map_index = hash;
+    while (true) : (map_index += 1) {
+        map_index &= map_mask;
+        const entry = &map.entries[map_index];
+        const index = entry.acquire().unwrap() orelse break;
+        if (entry.hash != hash) continue;
+        if (index.eqlSlice(key, ip)) return index.toOptional();
+    }
+    shard.mutate.string_map.mutex.lock();
+    defer shard.mutate.string_map.mutex.unlock();
+    if (map.entries != shard.shared.string_map.entries) {
+        map = shard.shared.string_map;
+        map_mask = map.header().mask();
+        map_index = hash;
+    }
+    while (true) : (map_index += 1) {
+        map_index &= map_mask;
+        const entry = &map.entries[map_index];
+        const index = entry.acquire().unwrap() orelse break;
+        if (entry.hash != hash) continue;
+        if (index.eqlSlice(key, ip)) return index.toOptional();
+    }
+    defer shard.mutate.string_map.len += 1;
+    const map_header = map.header().*;
+    if (shard.mutate.string_map.len < map_header.capacity * 3 / 5) {
+        const entry = &map.entries[map_index];
+        entry.hash = hash;
+        entry.release(value.toOptional());
+        return .none;
+    }
+    const new_map_capacity = map_header.capacity * 2;
+    const new_map_buf = try gpa.alignedAlloc(
+        u8,
+        Local.garbage_align,
+        @sizeOf(Local.Garbage.Node) + @sizeOf(Map.Header) +
+            new_map_capacity * @sizeOf(Map.Entry),
+    );
+    const new_node: *Local.Garbage.Node = @ptrCast(new_map_buf.ptr);
+    new_node.* = .{ .data = .{ .buf_len = new_map_buf.len } };
+    ip.local[@intFromEnum(tid)].garbage.prepend(new_node);
+    const new_map_entries = std.mem.bytesAsSlice(
+        Map.Entry,
+        new_map_buf[@sizeOf(Local.Garbage.Node) + @sizeOf(Map.Header) ..],
+    );
+    const new_map: Map = .{ .entries = new_map_entries.ptr };
+    new_map.header().* = .{ .capacity = new_map_capacity };
+    @memset(new_map_entries, .{ .value = .none, .hash = undefined });
+    const new_map_mask = new_map.header().mask();
+    map_index = 0;
+    while (map_index < map_header.capacity) : (map_index += 1) {
+        const entry = &map.entries[map_index];
+        const index = entry.value.unwrap() orelse continue;
+        const item_hash = entry.hash;
+        var new_map_index = item_hash;
+        while (true) : (new_map_index += 1) {
+            new_map_index &= new_map_mask;
+            const new_entry = &new_map.entries[new_map_index];
+            if (new_entry.value != .none) continue;
+            new_entry.* = .{
+                .value = index.toOptional(),
+                .hash = item_hash,
+            };
+            break;
+        }
+    }
+    map = new_map;
+    map_index = hash;
+    while (true) : (map_index += 1) {
+        map_index &= new_map_mask;
+        if (map.entries[map_index].value == .none) break;
+    }
+    map.entries[map_index] = .{
+        .value = value.toOptional(),
+        .hash = hash,
+    };
+    shard.shared.string_map.release(new_map);
+    return .none;
+}
+
+pub fn getString(ip: *InternPool, key: []const u8) OptionalNullTerminatedString {
+    const full_hash = Hash.hash(0, key);
+    const hash: u32 = @truncate(full_hash >> 32);
+    const shard = &ip.shards[@intCast(full_hash & (ip.shards.len - 1))];
+    const map = shard.shared.string_map.acquire();
+    const map_mask = map.header().mask();
+    var map_index = hash;
+    while (true) : (map_index += 1) {
+        map_index &= map_mask;
+        const entry = &map.entries[map_index];
+        const index = entry.acquire().unwrap() orelse return .none;
+        if (entry.hash != hash) continue;
+        if (index.eqlSlice(key, ip)) return index.toOptional();
+    }
 }
 
 pub fn typeOf(ip: *const InternPool, index: Index) Index {