Commit cf744aa182

Andrew Kelley <andrew@ziglang.org>
2025-11-21 21:02:59
std.Io.Threaded: slightly different semantics
while still preserving the guarantee about async() being assigned a unit of concurrency (or immediately running the task), this change: * retains the error from calling getCpuCount() * spawns all threads in detached mode, using WaitGroup to join them * treats all workers the same regardless of whether they are processing concurrent or async tasks. one thread pool does all the work, while respecting async and concurrent limits.
1 parent 13b537d
Changed files (3)
lib
std
lib/std/Io/Threaded/test.zig
@@ -10,7 +10,7 @@ test "concurrent vs main prevents deadlock via oversubscription" {
     defer threaded.deinit();
     const io = threaded.io();
 
-    threaded.cpu_count = 1;
+    threaded.async_limit = .nothing;
 
     var queue: Io.Queue(u8) = .init(&.{});
 
@@ -38,7 +38,7 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" {
     defer threaded.deinit();
     const io = threaded.io();
 
-    threaded.cpu_count = 1;
+    threaded.async_limit = .nothing;
 
     var queue: Io.Queue(u8) = .init(&.{});
 
lib/std/Io/Threaded.zig
@@ -22,12 +22,30 @@ mutex: std.Thread.Mutex = .{},
 cond: std.Thread.Condition = .{},
 run_queue: std.SinglyLinkedList = .{},
 join_requested: bool = false,
-threads: std.ArrayList(std.Thread),
 stack_size: usize,
-cpu_count: usize, // 0 means no limit
-concurrency_limit: usize, // 0 means no limit
-available_thread_count: usize = 0,
-one_shot_thread_count: usize = 0,
+/// All threads are spawned detached; this is how we wait until they all exit.
+wait_group: std.Thread.WaitGroup = .{},
+/// Maximum thread pool size (excluding main thread) when dispatching async
+/// tasks. Until this limit, calls to `Io.async` when all threads are busy will
+/// cause a new thread to be spawned and permanently added to the pool. After
+/// this limit, calls to `Io.async` when all threads are busy run the task
+/// immediately.
+///
+/// Defaults to a number equal to logical CPU cores.
+async_limit: Io.Limit,
+/// Maximum thread pool size (excluding main thread) for dispatching concurrent
+/// tasks. Until this limit, calls to `Io.concurrent` will increase the thread
+/// pool size.
+///
+/// concurrent tasks. After this number, calls to `Io.concurrent` return
+/// `error.ConcurrencyUnavailable`.
+concurrent_limit: Io.Limit = .unlimited,
+/// Error from calling `std.Thread.getCpuCount` in `init`.
+cpu_count_error: ?std.Thread.CpuCountError,
+/// Number of threads that are unavailable to take tasks. To calculate
+/// available count, subtract this from either `async_limit` or
+/// `concurrent_limit`.
+busy_count: usize = 0,
 
 wsa: if (is_windows) Wsa else struct {} = .{},
 
@@ -103,19 +121,18 @@ pub fn init(
 ) Threaded {
     if (builtin.single_threaded) return .init_single_threaded;
 
+    const cpu_count = std.Thread.getCpuCount();
+
     var t: Threaded = .{
         .allocator = gpa,
-        .threads = .empty,
         .stack_size = std.Thread.SpawnConfig.default_stack_size,
-        .cpu_count = std.Thread.getCpuCount() catch 0,
-        .concurrency_limit = 0,
+        .async_limit = if (cpu_count) |n| .limited(n - 1) else |_| .nothing,
+        .cpu_count_error = if (cpu_count) |_| null else |e| e,
         .old_sig_io = undefined,
         .old_sig_pipe = undefined,
         .have_signal_handler = false,
     };
 
-    t.threads.ensureTotalCapacity(gpa, t.cpu_count) catch {};
-
     if (posix.Sigaction != void) {
         // This causes sending `posix.SIG.IO` to thread to interrupt blocking
         // syscalls, returning `posix.E.INTR`.
@@ -140,19 +157,17 @@ pub fn init(
 /// * `deinit` is safe, but unnecessary to call.
 pub const init_single_threaded: Threaded = .{
     .allocator = .failing,
-    .threads = .empty,
     .stack_size = std.Thread.SpawnConfig.default_stack_size,
-    .cpu_count = 1,
-    .concurrency_limit = 0,
+    .async_limit = .nothing,
+    .cpu_count_error = null,
+    .concurrent_limit = .nothing,
     .old_sig_io = undefined,
     .old_sig_pipe = undefined,
     .have_signal_handler = false,
 };
 
 pub fn deinit(t: *Threaded) void {
-    const gpa = t.allocator;
     t.join();
-    t.threads.deinit(gpa);
     if (is_windows and t.wsa.status == .initialized) {
         if (ws2_32.WSACleanup() != 0) recoverableOsBugDetected();
     }
@@ -171,10 +186,12 @@ fn join(t: *Threaded) void {
         t.join_requested = true;
     }
     t.cond.broadcast();
-    for (t.threads.items) |thread| thread.join();
+    t.wait_group.wait();
 }
 
 fn worker(t: *Threaded) void {
+    defer t.wait_group.finish();
+
     t.mutex.lock();
     defer t.mutex.unlock();
 
@@ -184,20 +201,13 @@ fn worker(t: *Threaded) void {
             const closure: *Closure = @fieldParentPtr("node", closure_node);
             closure.start(closure);
             t.mutex.lock();
-            t.available_thread_count += 1;
+            t.busy_count -= 1;
         }
         if (t.join_requested) break;
         t.cond.wait(&t.mutex);
     }
 }
 
-fn oneShotWorker(t: *Threaded, closure: *Closure) void {
-    closure.start(closure);
-    t.mutex.lock();
-    defer t.mutex.unlock();
-    t.one_shot_thread_count -= 1;
-}
-
 pub fn io(t: *Threaded) Io {
     return .{
         .userdata = t,
@@ -488,7 +498,7 @@ fn async(
     start: *const fn (context: *const anyopaque, result: *anyopaque) void,
 ) ?*Io.AnyFuture {
     const t: *Threaded = @ptrCast(@alignCast(userdata));
-    if (t.cpu_count == 1 or builtin.single_threaded) {
+    if (builtin.single_threaded or t.async_limit == .nothing) {
         start(context.ptr, result.ptr);
         return null;
     }
@@ -500,35 +510,29 @@ fn async(
 
     t.mutex.lock();
 
-    if (t.available_thread_count == 0) {
-        if (t.cpu_count != 0 and t.threads.items.len >= t.cpu_count) {
-            t.mutex.unlock();
-            ac.deinit(gpa);
-            start(context.ptr, result.ptr);
-            return null;
-        }
+    const busy_count = t.busy_count;
 
-        t.threads.ensureUnusedCapacity(gpa, 1) catch {
-            t.mutex.unlock();
-            ac.deinit(gpa);
-            start(context.ptr, result.ptr);
-            return null;
-        };
+    if (busy_count >= @intFromEnum(t.async_limit)) {
+        t.mutex.unlock();
+        ac.deinit(gpa);
+        start(context.ptr, result.ptr);
+        return null;
+    }
 
-        const thread = std.Thread.spawn(
-            .{ .stack_size = t.stack_size },
-            worker,
-            .{t},
-        ) catch {
+    t.busy_count = busy_count + 1;
+
+    const pool_size = t.wait_group.value();
+    if (pool_size - busy_count == 0) {
+        t.wait_group.start();
+        const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
+            t.wait_group.finish();
+            t.busy_count = busy_count;
             t.mutex.unlock();
             ac.deinit(gpa);
             start(context.ptr, result.ptr);
             return null;
         };
-
-        t.threads.appendAssumeCapacity(thread);
-    } else {
-        t.available_thread_count -= 1;
+        thread.detach();
     }
 
     t.run_queue.prepend(&ac.closure.node);
@@ -550,47 +554,33 @@ fn concurrent(
     const t: *Threaded = @ptrCast(@alignCast(userdata));
 
     const gpa = t.allocator;
-    const ac = AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch {
+    const ac = AsyncClosure.init(gpa, result_len, result_alignment, context, context_alignment, start) catch
         return error.ConcurrencyUnavailable;
-    };
     errdefer ac.deinit(gpa);
 
     t.mutex.lock();
     defer t.mutex.unlock();
 
-    // If there's an avilable thread, use it.
-    if (t.available_thread_count > 0) {
-        t.available_thread_count -= 1;
-        t.run_queue.prepend(&ac.closure.node);
-        t.cond.signal();
-        return @ptrCast(ac);
-    }
+    const busy_count = t.busy_count;
 
-    // If we can spawn a normal worker, spawn it and use it.
-    if (t.cpu_count == 0 or t.threads.items.len < t.cpu_count) {
-        t.threads.ensureUnusedCapacity(gpa, 1) catch return error.ConcurrencyUnavailable;
+    if (busy_count >= @intFromEnum(t.concurrent_limit))
+        return error.ConcurrencyUnavailable;
+
+    t.busy_count = busy_count + 1;
+    errdefer t.busy_count = busy_count;
+
+    const pool_size = t.wait_group.value();
+    if (pool_size - busy_count == 0) {
+        t.wait_group.start();
+        errdefer t.wait_group.finish();
 
         const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch
             return error.ConcurrencyUnavailable;
-
-        t.threads.appendAssumeCapacity(thread);
-        t.run_queue.prepend(&ac.closure.node);
-        t.cond.signal();
-        return @ptrCast(ac);
+        thread.detach();
     }
 
-    // If we have a concurrencty limit and we havent' hit it yet,
-    // spawn a new one-shot thread.
-    if (t.concurrency_limit != 0 and t.one_shot_thread_count >= t.concurrency_limit)
-        return error.ConcurrencyUnavailable;
-
-    t.one_shot_thread_count += 1;
-    errdefer t.one_shot_thread_count -= 1;
-
-    const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, oneShotWorker, .{ t, &ac.closure }) catch
-        return error.ConcurrencyUnavailable;
-    thread.detach();
-
+    t.run_queue.prepend(&ac.closure.node);
+    t.cond.signal();
     return @ptrCast(ac);
 }
 
@@ -684,41 +674,37 @@ fn groupAsync(
     context_alignment: std.mem.Alignment,
     start: *const fn (*Io.Group, context: *const anyopaque) void,
 ) void {
-    if (builtin.single_threaded) return start(group, context.ptr);
-
     const t: *Threaded = @ptrCast(@alignCast(userdata));
+    if (builtin.single_threaded or t.async_limit == .nothing)
+        return start(group, context.ptr);
+
     const gpa = t.allocator;
     const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch
         return start(group, context.ptr);
 
     t.mutex.lock();
 
-    if (t.available_thread_count == 0) {
-        if (t.cpu_count != 0 and t.threads.items.len >= t.cpu_count) {
-            t.mutex.unlock();
-            gc.deinit(gpa);
-            return start(group, context.ptr);
-        }
+    const busy_count = t.busy_count;
 
-        t.threads.ensureUnusedCapacity(gpa, 1) catch {
-            t.mutex.unlock();
-            gc.deinit(gpa);
-            return start(group, context.ptr);
-        };
+    if (busy_count >= @intFromEnum(t.async_limit)) {
+        t.mutex.unlock();
+        gc.deinit(gpa);
+        return start(group, context.ptr);
+    }
 
-        const thread = std.Thread.spawn(
-            .{ .stack_size = t.stack_size },
-            worker,
-            .{t},
-        ) catch {
+    t.busy_count = busy_count + 1;
+
+    const pool_size = t.wait_group.value();
+    if (pool_size - busy_count == 0) {
+        t.wait_group.start();
+        const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
+            t.wait_group.finish();
+            t.busy_count = busy_count;
             t.mutex.unlock();
             gc.deinit(gpa);
             return start(group, context.ptr);
         };
-
-        t.threads.appendAssumeCapacity(thread);
-    } else {
-        t.available_thread_count -= 1;
+        thread.detach();
     }
 
     // Append to the group linked list inside the mutex to make `Io.Group.async` thread-safe.
lib/std/Thread/WaitGroup.zig
@@ -60,6 +60,10 @@ pub fn isDone(wg: *WaitGroup) bool {
     return (state / one_pending) == 0;
 }
 
+pub fn value(wg: *WaitGroup) usize {
+    return wg.state.load(.monotonic) / one_pending;
+}
+
 // Spawns a new thread for the task. This is appropriate when the callee
 // delegates all work.
 pub fn spawnManager(