Commit b0fe7eef54

Jacob Young <jacobly0@users.noreply.github.com>
2024-07-15 09:19:15
InternPool: fix various data structure invariants
1 parent e324547
Changed files (1)
src/InternPool.zig
@@ -939,8 +939,12 @@ const Shard = struct {
                     return @atomicLoad(Value, &entry.value, .acquire);
                 }
                 fn release(entry: *Entry, value: Value) void {
+                    assert(value != .none);
                     @atomicStore(Value, &entry.value, value, .release);
                 }
+                fn resetUnordered(entry: *Entry) void {
+                    @atomicStore(Value, &entry.value, .none, .unordered);
+                }
             };
         };
     }
@@ -6583,35 +6587,55 @@ const GetOrPutKey = union(enum) {
     },
 
     fn put(gop: *GetOrPutKey) Index {
-        return gop.putAt(0);
-    }
-    fn putAt(gop: *GetOrPutKey, offset: u32) Index {
         switch (gop.*) {
             .existing => unreachable,
-            .new => |info| {
+            .new => |*info| {
                 const index = Index.Unwrapped.wrap(.{
                     .tid = info.tid,
-                    .index = info.ip.getLocal(info.tid).mutate.items.len - 1 - offset,
+                    .index = info.ip.getLocal(info.tid).mutate.items.len - 1,
                 }, info.ip);
-                info.shard.shared.map.entries[info.map_index].release(index);
+                gop.putTentative(index);
+                gop.putFinal(index);
+                return index;
+            },
+        }
+    }
+
+    fn putTentative(gop: *GetOrPutKey, index: Index) void {
+        assert(index != .none);
+        switch (gop.*) {
+            .existing => unreachable,
+            .new => |*info| gop.new.shard.shared.map.entries[info.map_index].release(index),
+        }
+    }
+
+    fn putFinal(gop: *GetOrPutKey, index: Index) void {
+        assert(index != .none);
+        switch (gop.*) {
+            .existing => unreachable,
+            .new => |info| {
+                assert(info.shard.shared.map.entries[info.map_index].value == index);
                 info.shard.mutate.map.len += 1;
                 info.shard.mutate.map.mutex.unlock();
                 gop.* = .{ .existing = index };
-                return index;
             },
         }
     }
 
-    fn assign(gop: *GetOrPutKey, new_gop: GetOrPutKey) void {
-        gop.deinit();
-        gop.* = new_gop;
+    fn cancel(gop: *GetOrPutKey) void {
+        switch (gop.*) {
+            .existing => {},
+            .new => |info| info.shard.mutate.map.mutex.unlock(),
+        }
+        gop.* = .{ .existing = undefined };
     }
 
     fn deinit(gop: *GetOrPutKey) void {
         switch (gop.*) {
             .existing => {},
-            .new => |info| info.shard.mutate.map.mutex.unlock(),
+            .new => |info| info.shard.shared.map.entries[info.map_index].resetUnordered(),
         }
+        gop.cancel();
         gop.* = undefined;
     }
 };
@@ -6620,6 +6644,15 @@ fn getOrPutKey(
     gpa: Allocator,
     tid: Zcu.PerThread.Id,
     key: Key,
+) Allocator.Error!GetOrPutKey {
+    return ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, key, 0);
+}
+fn getOrPutKeyEnsuringAdditionalCapacity(
+    ip: *InternPool,
+    gpa: Allocator,
+    tid: Zcu.PerThread.Id,
+    key: Key,
+    additional_capacity: u32,
 ) Allocator.Error!GetOrPutKey {
     const full_hash = key.hash64(ip);
     const hash: u32 = @truncate(full_hash >> 32);
@@ -6655,11 +6688,16 @@ fn getOrPutKey(
         }
     }
     const map_header = map.header().*;
-    if (shard.mutate.map.len >= map_header.capacity * 3 / 5) {
+    const required = shard.mutate.map.len + additional_capacity;
+    if (required >= map_header.capacity * 3 / 5) {
         const arena_state = &ip.getLocal(tid).mutate.arena;
         var arena = arena_state.promote(gpa);
         defer arena_state.* = arena.state;
-        const new_map_capacity = map_header.capacity * 2;
+        var new_map_capacity = map_header.capacity;
+        while (true) {
+            new_map_capacity *= 2;
+            if (required < new_map_capacity * 3 / 5) break;
+        }
         const new_map_buf = try arena.allocator().alignedAlloc(
             u8,
             Map.alignment,
@@ -6728,10 +6766,11 @@ pub fn get(ip: *InternPool, gpa: Allocator, tid: Zcu.PerThread.Id, key: Key) All
             assert(ptr_type.sentinel == .none or ip.typeOf(ptr_type.sentinel) == ptr_type.child);
 
             if (ptr_type.flags.size == .Slice) {
+                gop.cancel();
                 var new_key = key;
                 new_key.ptr_type.flags.size = .Many;
                 const ptr_type_index = try ip.get(gpa, tid, new_key);
-                gop.assign(try ip.getOrPutKey(gpa, tid, key));
+                gop = try ip.getOrPutKey(gpa, tid, key);
 
                 try items.ensureUnusedCapacity(1);
                 items.appendAssumeCapacity(.{
@@ -6911,9 +6950,10 @@ pub fn get(ip: *InternPool, gpa: Allocator, tid: Zcu.PerThread.Id, key: Key) All
                 },
                 .anon_decl => |anon_decl| if (ptrsHaveSameAlignment(ip, ptr.ty, ptr_type, anon_decl.orig_ty)) item: {
                     if (ptr.ty != anon_decl.orig_ty) {
+                        gop.cancel();
                         var new_key = key;
                         new_key.ptr.base_addr.anon_decl.orig_ty = ptr.ty;
-                        gop.assign(try ip.getOrPutKey(gpa, tid, new_key));
+                        gop = try ip.getOrPutKey(gpa, tid, new_key);
                         if (gop == .existing) return gop.existing;
                     }
                     break :item .{
@@ -6984,11 +7024,12 @@ pub fn get(ip: *InternPool, gpa: Allocator, tid: Zcu.PerThread.Id, key: Key) All
                         },
                         else => unreachable,
                     }
+                    gop.cancel();
                     const index_index = try ip.get(gpa, tid, .{ .int = .{
                         .ty = .usize_type,
                         .storage = .{ .u64 = base_index.index },
                     } });
-                    gop.assign(try ip.getOrPutKey(gpa, tid, key));
+                    gop = try ip.getOrPutKey(gpa, tid, key);
                     try items.ensureUnusedCapacity(1);
                     items.appendAssumeCapacity(.{
                         .tag = switch (ptr.base_addr) {
@@ -7397,11 +7438,12 @@ pub fn get(ip: *InternPool, gpa: Allocator, tid: Zcu.PerThread.Id, key: Key) All
                 }
                 const elem = switch (aggregate.storage) {
                     .bytes => |bytes| elem: {
+                        gop.cancel();
                         const elem = try ip.get(gpa, tid, .{ .int = .{
                             .ty = .u8_type,
                             .storage = .{ .u64 = bytes.at(0, ip) },
                         } });
-                        gop.assign(try ip.getOrPutKey(gpa, tid, key));
+                        gop = try ip.getOrPutKey(gpa, tid, key);
                         try items.ensureUnusedCapacity(1);
                         break :elem elem;
                     },
@@ -8219,9 +8261,9 @@ pub fn getFuncDeclIes(
         extra.mutate.len = prev_extra_len;
     }
 
-    var func_gop = try ip.getOrPutKey(gpa, tid, .{
+    var func_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{
         .func = extraFuncDecl(tid, extra.list.*, func_decl_extra_index),
-    });
+    }, 3);
     defer func_gop.deinit();
     if (func_gop == .existing) {
         // An existing function type was found; undo the additions to our two arrays.
@@ -8229,23 +8271,28 @@ pub fn getFuncDeclIes(
         extra.mutate.len = prev_extra_len;
         return func_gop.existing;
     }
-    var error_union_type_gop = try ip.getOrPutKey(gpa, tid, .{ .error_union_type = .{
+    func_gop.putTentative(func_index);
+    var error_union_type_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{ .error_union_type = .{
         .error_set_type = error_set_type,
         .payload_type = key.bare_return_type,
-    } });
+    } }, 2);
     defer error_union_type_gop.deinit();
-    var error_set_type_gop = try ip.getOrPutKey(gpa, tid, .{
+    error_union_type_gop.putTentative(error_union_type);
+    var error_set_type_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{
         .inferred_error_set_type = func_index,
-    });
+    }, 1);
     defer error_set_type_gop.deinit();
+    error_set_type_gop.putTentative(error_set_type);
     var func_ty_gop = try ip.getOrPutKey(gpa, tid, .{
         .func_type = extraFuncType(tid, extra.list.*, func_type_extra_index),
     });
     defer func_ty_gop.deinit();
-    assert(func_gop.putAt(3) == func_index);
-    assert(error_union_type_gop.putAt(2) == error_union_type);
-    assert(error_set_type_gop.putAt(1) == error_set_type);
-    assert(func_ty_gop.putAt(0) == func_ty);
+    func_ty_gop.putTentative(func_ty);
+
+    func_gop.putFinal(func_index);
+    error_union_type_gop.putFinal(error_union_type);
+    error_set_type_gop.putFinal(error_set_type);
+    func_ty_gop.putFinal(func_ty);
     return func_index;
 }
 
@@ -8504,9 +8551,9 @@ pub fn getFuncInstanceIes(
         extra.mutate.len = prev_extra_len;
     }
 
-    var func_gop = try ip.getOrPutKey(gpa, tid, .{
+    var func_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{
         .func = ip.extraFuncInstance(tid, extra.list.*, func_extra_index),
-    });
+    }, 3);
     defer func_gop.deinit();
     if (func_gop == .existing) {
         // Hot path: undo the additions to our two arrays.
@@ -8514,19 +8561,23 @@ pub fn getFuncInstanceIes(
         extra.mutate.len = prev_extra_len;
         return func_gop.existing;
     }
-    var error_union_type_gop = try ip.getOrPutKey(gpa, tid, .{ .error_union_type = .{
+    func_gop.putTentative(func_index);
+    var error_union_type_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{ .error_union_type = .{
         .error_set_type = error_set_type,
         .payload_type = arg.bare_return_type,
-    } });
+    } }, 2);
     defer error_union_type_gop.deinit();
-    var error_set_type_gop = try ip.getOrPutKey(gpa, tid, .{
+    error_union_type_gop.putTentative(error_union_type);
+    var error_set_type_gop = try ip.getOrPutKeyEnsuringAdditionalCapacity(gpa, tid, .{
         .inferred_error_set_type = func_index,
-    });
+    }, 1);
     defer error_set_type_gop.deinit();
+    error_set_type_gop.putTentative(error_set_type);
     var func_ty_gop = try ip.getOrPutKey(gpa, tid, .{
         .func_type = extraFuncType(tid, extra.list.*, func_type_extra_index),
     });
     defer func_ty_gop.deinit();
+    func_ty_gop.putTentative(func_ty);
     try finishFuncInstance(
         ip,
         gpa,
@@ -8538,10 +8589,11 @@ pub fn getFuncInstanceIes(
         arg.alignment,
         arg.section,
     );
-    assert(func_gop.putAt(3) == func_index);
-    assert(error_union_type_gop.putAt(2) == error_union_type);
-    assert(error_set_type_gop.putAt(1) == error_set_type);
-    assert(func_ty_gop.putAt(0) == func_ty);
+
+    func_gop.putFinal(func_index);
+    error_union_type_gop.putFinal(error_union_type);
+    error_set_type_gop.putFinal(error_set_type);
+    func_ty_gop.putFinal(func_ty);
     return func_index;
 }
 
@@ -10837,19 +10889,18 @@ pub fn getBackingDecl(ip: *const InternPool, val: Index) OptionalDeclIndex {
     while (true) {
         const unwrapped_base = base.unwrap(ip);
         const base_item = unwrapped_base.getItem(ip);
-        const base_extra_items = unwrapped_base.getExtra(ip).view().items(.@"0");
         switch (base_item.tag) {
-            .ptr_decl => return @enumFromInt(base_extra_items[
+            .ptr_decl => return @enumFromInt(unwrapped_base.getExtra(ip).view().items(.@"0")[
                 base_item.data + std.meta.fieldIndex(PtrDecl, "decl").?
             ]),
             inline .ptr_eu_payload,
             .ptr_opt_payload,
             .ptr_elem,
             .ptr_field,
-            => |tag| base = @enumFromInt(base_extra_items[
+            => |tag| base = @enumFromInt(unwrapped_base.getExtra(ip).view().items(.@"0")[
                 base_item.data + std.meta.fieldIndex(tag.Payload(), "base").?
             ]),
-            .ptr_slice => base = @enumFromInt(base_extra_items[
+            .ptr_slice => base = @enumFromInt(unwrapped_base.getExtra(ip).view().items(.@"0")[
                 base_item.data + std.meta.fieldIndex(PtrSlice, "ptr").?
             ]),
             else => return .none,