Commit 839c453d88

Andrew Kelley <andrew@ziglang.org>
2025-02-07 06:47:46
std.heap.SmpAllocator: eliminate the global mutex
1 parent 60765a9
Changed files (1)
lib
lib/std/heap/SmpAllocator.zig
@@ -39,20 +39,14 @@ const Allocator = std.mem.Allocator;
 const SmpAllocator = @This();
 const PageAllocator = std.heap.PageAllocator;
 
-/// Protects the state in this struct (global state), except for `threads`
-/// which each have their own mutex.
-mutex: std.Thread.Mutex,
-next_thread_index: u32,
 cpu_count: u32,
 threads: [max_thread_count]Thread,
 
 var global: SmpAllocator = .{
-    .mutex = .{},
-    .next_thread_index = 0,
     .threads = @splat(.{}),
     .cpu_count = 0,
 };
-threadlocal var thread_id: Thread.Id = .none;
+threadlocal var thread_index: u32 = 0;
 
 const max_thread_count = 128;
 const slab_len: usize = @max(std.heap.page_size_max, 256 * 1024);
@@ -74,60 +68,22 @@ const Thread = struct {
     /// For each size class, points to the freed pointer.
     frees: [size_class_count]usize = @splat(0),
 
-    /// Index into `SmpAllocator.threads`.
-    const Id = enum(usize) {
-        none = 0,
-        first = 1,
-        _,
-
-        fn fromIndex(index: usize) Id {
-            return @enumFromInt(index + 1);
-        }
-
-        fn toIndex(id: Id) usize {
-            return @intFromEnum(id) - 1;
-        }
-    };
-
     fn lock() *Thread {
-        const id = thread_id;
-        if (id != .none) {
-            var index = id.toIndex();
-            {
-                const t = &global.threads[index];
-                if (t.mutex.tryLock()) return t;
-            }
-            const cpu_count = global.cpu_count;
-            assert(cpu_count != 0);
-            while (true) {
-                index = (index + 1) % cpu_count;
-                const t = &global.threads[index];
-                if (t.mutex.tryLock()) {
-                    thread_id = .fromIndex(index);
-                    return t;
-                }
+        var index = thread_index;
+        {
+            const t = &global.threads[index];
+            if (t.mutex.tryLock()) {
+                @branchHint(.likely);
+                return t;
             }
         }
+        const cpu_count = getCpuCount();
+        assert(cpu_count != 0);
         while (true) {
-            const thread_index = i: {
-                global.mutex.lock();
-                defer global.mutex.unlock();
-                const cpu_count = c: {
-                    const cpu_count = global.cpu_count;
-                    if (cpu_count == 0) {
-                        const n: u32 = @intCast(@max(std.Thread.getCpuCount() catch max_thread_count, max_thread_count));
-                        global.cpu_count = n;
-                        break :c n;
-                    }
-                    break :c cpu_count;
-                };
-                const thread_index = global.next_thread_index;
-                global.next_thread_index = @intCast((thread_index + 1) % cpu_count);
-                break :i thread_index;
-            };
-            const t = &global.threads[thread_index];
+            index = (index + 1) % cpu_count;
+            const t = &global.threads[index];
             if (t.mutex.tryLock()) {
-                thread_id = .fromIndex(thread_index);
+                thread_index = index;
                 return t;
             }
         }
@@ -138,6 +94,13 @@ const Thread = struct {
     }
 };
 
+fn getCpuCount() u32 {
+    const cpu_count = @atomicLoad(u32, &global.cpu_count, .unordered);
+    if (cpu_count != 0) return cpu_count;
+    const n: u32 = @intCast(@max(std.Thread.getCpuCount() catch max_thread_count, max_thread_count));
+    return if (@cmpxchgStrong(u32, &global.cpu_count, 0, n, .monotonic, .monotonic)) |other| other else n;
+}
+
 pub const vtable: Allocator.VTable = .{
     .alloc = alloc,
     .resize = resize,
@@ -159,8 +122,8 @@ fn alloc(context: *anyopaque, len: usize, alignment: mem.Alignment, ra: usize) ?
     }
 
     const slot_size = slotSize(class);
-    const max_search = 2;
-    var search_count: u32 = 0;
+    const max_search = 1;
+    var search_count: u8 = 0;
 
     var t = Thread.lock();
 
@@ -191,15 +154,14 @@ fn alloc(context: *anyopaque, len: usize, alignment: mem.Alignment, ra: usize) ?
         }
 
         t.unlock();
-        t = undefined;
-        const cpu_count = global.cpu_count;
+        const cpu_count = getCpuCount();
         assert(cpu_count != 0);
-        var index = thread_id.toIndex();
+        var index = thread_index;
         while (true) {
             index = (index + 1) % cpu_count;
             t = &global.threads[index];
             if (t.mutex.tryLock()) {
-                thread_id = .fromIndex(index);
+                thread_index = index;
                 search_count += 1;
                 continue :outer;
             }