Commit 1754e014f5

Andrew Kelley <andrew@ziglang.org>
2025-02-07 23:05:28
std.heap.SmpAllocator: rotate on free sometimes
* slab length reduced to 64K * track freelist length with u8s * on free(), rotate if freelist length exceeds max_freelist_len Prevents memory leakage in the scenario where one thread only allocates and another thread only frees.
1 parent a9d3005
Changed files (1)
lib
lib/std/heap/SmpAllocator.zig
@@ -47,10 +47,16 @@ var global: SmpAllocator = .{
 threadlocal var thread_index: u32 = 0;
 
 const max_thread_count = 128;
-const slab_len: usize = @max(std.heap.page_size_max, 256 * 1024);
+const slab_len: usize = @max(std.heap.page_size_max, 64 * 1024);
 /// Because of storing free list pointers, the minimum size class is 3.
 const min_class = math.log2(@sizeOf(usize));
 const size_class_count = math.log2(slab_len) - min_class;
+/// When a freelist length exceeds this number, a `free` will rotate up to
+/// `max_free_search` times before pushing.
+const max_freelist_len: u8 = 16;
+const max_free_search = 1;
+/// Before mapping a fresh page, `alloc` will rotate this many times.
+const max_alloc_search = 1;
 
 const Thread = struct {
     /// Avoid false sharing.
@@ -62,9 +68,13 @@ const Thread = struct {
     /// to support freelist reclamation.
     mutex: std.Thread.Mutex = .{},
 
+    /// For each size class, tracks the next address to be returned from
+    /// `alloc` when the freelist is empty.
     next_addrs: [size_class_count]usize = @splat(0),
     /// For each size class, points to the freed pointer.
     frees: [size_class_count]usize = @splat(0),
+    /// For each size class, tracks the number of items in the freelist.
+    freelist_lens: [size_class_count]u8 = @splat(0),
 
     fn lock() *Thread {
         var index = thread_index;
@@ -121,7 +131,6 @@ fn alloc(context: *anyopaque, len: usize, alignment: mem.Alignment, ra: usize) ?
 
     const slot_size = slotSize(class);
     assert(slab_len % slot_size == 0);
-    const max_search = 1;
     var search_count: u8 = 0;
 
     var t = Thread.lock();
@@ -133,6 +142,7 @@ fn alloc(context: *anyopaque, len: usize, alignment: mem.Alignment, ra: usize) ?
             defer t.unlock();
             const node: *usize = @ptrFromInt(top_free_ptr);
             t.frees[class] = node.*;
+            t.freelist_lens[class] -|= 1;
             return @ptrFromInt(top_free_ptr);
         }
 
@@ -144,12 +154,13 @@ fn alloc(context: *anyopaque, len: usize, alignment: mem.Alignment, ra: usize) ?
             return @ptrFromInt(next_addr);
         }
 
-        if (search_count >= max_search) {
+        if (search_count >= max_alloc_search) {
             @branchHint(.likely);
             defer t.unlock();
             // slab alignment here ensures the % slab len earlier catches the end of slots.
             const slab = PageAllocator.map(slab_len, .fromByteUnits(slab_len)) orelse return null;
             t.next_addrs[class] = @intFromPtr(slab) + slot_size;
+            t.freelist_lens[class] = 0;
             return slab;
         }
 
@@ -203,12 +214,42 @@ fn free(context: *anyopaque, memory: []u8, alignment: mem.Alignment, ra: usize)
     }
 
     const node: *usize = @alignCast(@ptrCast(memory.ptr));
+    var search_count: u8 = 0;
+
+    var t = Thread.lock();
+
+    outer: while (true) {
+        const freelist_len = t.freelist_lens[class];
+        if (freelist_len < max_freelist_len) {
+            @branchHint(.likely);
+            defer t.unlock();
+            node.* = t.frees[class];
+            t.frees[class] = @intFromPtr(node);
+            return;
+        }
 
-    const t = Thread.lock();
-    defer t.unlock();
+        if (search_count >= max_free_search) {
+            defer t.unlock();
+            t.freelist_lens[class] = freelist_len +| 1;
+            node.* = t.frees[class];
+            t.frees[class] = @intFromPtr(node);
+            return;
+        }
 
-    node.* = t.frees[class];
-    t.frees[class] = @intFromPtr(node);
+        t.unlock();
+        const cpu_count = getCpuCount();
+        assert(cpu_count != 0);
+        var index = thread_index;
+        while (true) {
+            index = (index + 1) % cpu_count;
+            t = &global.threads[index];
+            if (t.mutex.tryLock()) {
+                thread_index = index;
+                search_count += 1;
+                continue :outer;
+            }
+        }
+    }
 }
 
 fn sizeClassIndex(len: usize, alignment: mem.Alignment) usize {