Commit 14c3dc4c49
Changed files (3)
lib/std/Io/ThreadPool.zig
@@ -0,0 +1,852 @@
+const builtin = @import("builtin");
+const std = @import("../std.zig");
+const Allocator = std.mem.Allocator;
+const assert = std.debug.assert;
+const WaitGroup = std.Thread.WaitGroup;
+const Io = std.Io;
+const Pool = @This();
+
+/// Must be a thread-safe allocator.
+allocator: std.mem.Allocator,
+mutex: std.Thread.Mutex = .{},
+cond: std.Thread.Condition = .{},
+run_queue: std.SinglyLinkedList = .{},
+is_running: bool = true,
+threads: std.ArrayListUnmanaged(std.Thread),
+ids: if (builtin.single_threaded) struct {
+ inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
+ fn getIndex(_: @This(), _: std.Thread.Id) usize {
+ return 0;
+ }
+} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),
+stack_size: usize,
+
+threadlocal var current_closure: ?*AsyncClosure = null;
+
+pub const Runnable = struct {
+ runFn: RunProto,
+ node: std.SinglyLinkedList.Node = .{},
+};
+
+pub const RunProto = *const fn (*Runnable, id: ?usize) void;
+
+pub const Options = struct {
+ allocator: std.mem.Allocator,
+ n_jobs: ?usize = null,
+ track_ids: bool = false,
+ stack_size: usize = std.Thread.SpawnConfig.default_stack_size,
+};
+
+pub fn init(pool: *Pool, options: Options) !void {
+ const gpa = options.allocator;
+ const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
+ const threads = try gpa.alloc(std.Thread, thread_count);
+ errdefer gpa.free(threads);
+
+ pool.* = .{
+ .allocator = gpa,
+ .threads = .initBuffer(threads),
+ .ids = .{},
+ .stack_size = options.stack_size,
+ };
+
+ if (builtin.single_threaded) return;
+
+ if (options.track_ids) {
+ try pool.ids.ensureTotalCapacity(gpa, 1 + thread_count);
+ pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
+ }
+}
+
+pub fn deinit(pool: *Pool) void {
+ const gpa = pool.allocator;
+ pool.join();
+ pool.threads.deinit(gpa);
+ pool.ids.deinit(gpa);
+ pool.* = undefined;
+}
+
+fn join(pool: *Pool) void {
+ if (builtin.single_threaded) return;
+
+ {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ // ensure future worker threads exit the dequeue loop
+ pool.is_running = false;
+ }
+
+ // wake up any sleeping threads (this can be done outside the mutex)
+ // then wait for all the threads we know are spawned to complete.
+ pool.cond.broadcast();
+ for (pool.threads.items) |thread| thread.join();
+}
+
+/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
+/// `WaitGroup.finish` after it returns.
+///
+/// In the case that queuing the function call fails to allocate memory, or the
+/// target is single-threaded, the function is called directly.
+pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
+ wait_group.start();
+
+ if (builtin.single_threaded) {
+ @call(.auto, func, args);
+ wait_group.finish();
+ return;
+ }
+
+ const Args = @TypeOf(args);
+ const Closure = struct {
+ arguments: Args,
+ pool: *Pool,
+ runnable: Runnable = .{ .runFn = runFn },
+ wait_group: *WaitGroup,
+
+ fn runFn(runnable: *Runnable, _: ?usize) void {
+ const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
+ @call(.auto, func, closure.arguments);
+ closure.wait_group.finish();
+ closure.pool.allocator.destroy(closure);
+ }
+ };
+
+ pool.mutex.lock();
+
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
+ pool.mutex.unlock();
+ @call(.auto, func, args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+
+ pool.mutex.unlock();
+ pool.cond.signal();
+}
+
+/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
+/// `WaitGroup.finish` after it returns.
+///
+/// The first argument passed to `func` is a dense `usize` thread id; the rest
+/// of the arguments are passed from `args`. Requires the pool to have been
+/// initialized with `.track_ids = true`.
+///
+/// In the case that queuing the function call fails to allocate memory, or the
+/// target is single-threaded, the function is called directly.
+pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
+ wait_group.start();
+
+ if (builtin.single_threaded) {
+ @call(.auto, func, .{0} ++ args);
+ wait_group.finish();
+ return;
+ }
+
+ const Args = @TypeOf(args);
+ const Closure = struct {
+ arguments: Args,
+ pool: *Pool,
+ runnable: Runnable = .{ .runFn = runFn },
+ wait_group: *WaitGroup,
+
+ fn runFn(runnable: *Runnable, id: ?usize) void {
+ const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
+ @call(.auto, func, .{id.?} ++ closure.arguments);
+ closure.wait_group.finish();
+ closure.pool.allocator.destroy(closure);
+ }
+ };
+
+ pool.mutex.lock();
+
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
+ const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
+ pool.mutex.unlock();
+ @call(.auto, func, .{id.?} ++ args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+
+ pool.mutex.unlock();
+ pool.cond.signal();
+}
+
+pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) void {
+ if (builtin.single_threaded) {
+ @call(.auto, func, args);
+ return;
+ }
+
+ const Args = @TypeOf(args);
+ const Closure = struct {
+ arguments: Args,
+ pool: *Pool,
+ runnable: Runnable = .{ .runFn = runFn },
+
+ fn runFn(runnable: *Runnable, _: ?usize) void {
+ const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
+ @call(.auto, func, closure.arguments);
+ closure.pool.allocator.destroy(closure);
+ }
+ };
+
+ pool.mutex.lock();
+
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
+ pool.mutex.unlock();
+ @call(.auto, func, args);
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+
+ pool.mutex.unlock();
+ pool.cond.signal();
+}
+
+test spawn {
+ const TestFn = struct {
+ fn checkRun(completed: *bool) void {
+ completed.* = true;
+ }
+ };
+
+ var completed: bool = false;
+
+ {
+ var pool: Pool = undefined;
+ try pool.init(.{
+ .allocator = std.testing.allocator,
+ });
+ defer pool.deinit();
+ pool.spawn(TestFn.checkRun, .{&completed});
+ }
+
+ try std.testing.expectEqual(true, completed);
+}
+
+fn worker(pool: *Pool) void {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
+
+ const id: ?usize = if (pool.ids.count() > 0) @intCast(pool.ids.count()) else null;
+ if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
+
+ while (true) {
+ while (pool.run_queue.popFirst()) |run_node| {
+ // Temporarily unlock the mutex in order to execute the run_node
+ pool.mutex.unlock();
+ defer pool.mutex.lock();
+
+ const runnable: *Runnable = @fieldParentPtr("node", run_node);
+ runnable.runFn(runnable, id);
+ }
+
+ // Stop executing instead of waiting if the thread pool is no longer running.
+ if (pool.is_running) {
+ pool.cond.wait(&pool.mutex);
+ } else {
+ break;
+ }
+ }
+}
+
+pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
+ var id: ?usize = null;
+
+ while (!wait_group.isDone()) {
+ pool.mutex.lock();
+ if (pool.run_queue.popFirst()) |run_node| {
+ id = id orelse pool.ids.getIndex(std.Thread.getCurrentId());
+ pool.mutex.unlock();
+ const runnable: *Runnable = @fieldParentPtr("node", run_node);
+ runnable.runFn(runnable, id);
+ continue;
+ }
+
+ pool.mutex.unlock();
+ wait_group.wait();
+ return;
+ }
+}
+
+pub fn getIdCount(pool: *Pool) usize {
+ return @intCast(1 + pool.threads.items.len);
+}
+
+pub fn io(pool: *Pool) Io {
+ return .{
+ .userdata = pool,
+ .vtable = &.{
+ .async = async,
+ .await = await,
+ .go = go,
+ .cancel = cancel,
+ .cancelRequested = cancelRequested,
+ .select = select,
+
+ .mutexLock = mutexLock,
+ .mutexUnlock = mutexUnlock,
+
+ .conditionWait = conditionWait,
+ .conditionWake = conditionWake,
+
+ .createFile = createFile,
+ .openFile = openFile,
+ .closeFile = closeFile,
+ .pread = pread,
+ .pwrite = pwrite,
+
+ .now = now,
+ .sleep = sleep,
+ },
+ };
+}
+
+const AsyncClosure = struct {
+ func: *const fn (context: *anyopaque, result: *anyopaque) void,
+ runnable: Runnable = .{ .runFn = runFn },
+ reset_event: std.Thread.ResetEvent,
+ select_condition: ?*std.Thread.ResetEvent,
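+    /// 0 while no worker is executing the closure, the executing worker's
+    /// thread id while it runs, or `canceling_tid` once cancellation has
+    /// been requested (see `runFn` and `cancel`).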
+ cancel_tid: std.Thread.Id,
+ context_offset: usize,
+ result_offset: usize,
+
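+    /// Sentinel stored in `select_condition` once the closure has completed.
+    /// It is a bogus aligned address and is never dereferenced.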
+ const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(@alignOf(std.Thread.ResetEvent));
+
+ const canceling_tid: std.Thread.Id = switch (@typeInfo(std.Thread.Id)) {
+ .int => |int_info| switch (int_info.signedness) {
+ .signed => -1,
+ .unsigned => std.math.maxInt(std.Thread.Id),
+ },
+ .pointer => @ptrFromInt(std.math.maxInt(usize)),
+ else => @compileError("unsupported std.Thread.Id: " ++ @typeName(std.Thread.Id)),
+ };
+
+ fn runFn(runnable: *Pool.Runnable, _: ?usize) void {
+ const closure: *AsyncClosure = @alignCast(@fieldParentPtr("runnable", runnable));
+ const tid = std.Thread.getCurrentId();
+ if (@cmpxchgStrong(
+ std.Thread.Id,
+ &closure.cancel_tid,
+ 0,
+ tid,
+ .acq_rel,
+ .acquire,
+ )) |cancel_tid| {
+ assert(cancel_tid == canceling_tid);
+ return;
+ }
+ current_closure = closure;
+ closure.func(closure.contextPointer(), closure.resultPointer());
+ current_closure = null;
+ if (@cmpxchgStrong(
+ std.Thread.Id,
+ &closure.cancel_tid,
+ tid,
+ 0,
+ .acq_rel,
+ .acquire,
+ )) |cancel_tid| assert(cancel_tid == canceling_tid);
+
+ if (@atomicRmw(
+ ?*std.Thread.ResetEvent,
+ &closure.select_condition,
+ .Xchg,
+ done_reset_event,
+ .release,
+ )) |select_reset| {
+ assert(select_reset != done_reset_event);
+ select_reset.set();
+ }
+ closure.reset_event.set();
+ }
+
+ fn contextOffset(context_alignment: std.mem.Alignment) usize {
+ return context_alignment.forward(@sizeOf(AsyncClosure));
+ }
+
+ fn resultOffset(
+ context_alignment: std.mem.Alignment,
+ context_len: usize,
+ result_alignment: std.mem.Alignment,
+ ) usize {
+ return result_alignment.forward(contextOffset(context_alignment) + context_len);
+ }
+
+ fn resultPointer(closure: *AsyncClosure) [*]u8 {
+ const base: [*]u8 = @ptrCast(closure);
+ return base + closure.result_offset;
+ }
+
+ fn contextPointer(closure: *AsyncClosure) [*]u8 {
+ const base: [*]u8 = @ptrCast(closure);
+ return base + closure.context_offset;
+ }
+
+ fn waitAndFree(closure: *AsyncClosure, gpa: Allocator, result: []u8) void {
+ closure.reset_event.wait();
+ const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure);
+ @memcpy(result, closure.resultPointer()[0..result.len]);
+ gpa.free(base[0 .. closure.result_offset + result.len]);
+ }
+};
+
+fn async(
+ userdata: ?*anyopaque,
+ result: []u8,
+ result_alignment: std.mem.Alignment,
+ context: []const u8,
+ context_alignment: std.mem.Alignment,
+ start: *const fn (context: *const anyopaque, result: *anyopaque) void,
+) ?*Io.AnyFuture {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ pool.mutex.lock();
+
+ const gpa = pool.allocator;
+ const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
+ const result_offset = result_alignment.forward(context_offset + context.len);
+ const n = result_offset + result.len;
+ const closure: *AsyncClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, .of(AsyncClosure), n) catch {
+ pool.mutex.unlock();
+ start(context.ptr, result.ptr);
+ return null;
+ }));
+ closure.* = .{
+ .func = start,
+ .context_offset = context_offset,
+ .result_offset = result_offset,
+ .reset_event = .{},
+ .cancel_tid = 0,
+ .select_condition = null,
+ };
+ @memcpy(closure.contextPointer()[0..context.len], context);
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+
+ pool.mutex.unlock();
+ pool.cond.signal();
+
+ return @ptrCast(closure);
+}
+
+const DetachedClosure = struct {
+ pool: *Pool,
+ func: *const fn (context: *anyopaque) void,
+    runnable: Runnable = .{ .runFn = runFn },
+ context_alignment: std.mem.Alignment,
+ context_len: usize,
+
+ fn runFn(runnable: *Pool.Runnable, _: ?usize) void {
+        const closure: *DetachedClosure = @alignCast(@fieldParentPtr("runnable", runnable));
+ closure.func(closure.contextPointer());
+ const gpa = closure.pool.allocator;
+ const base: [*]align(@alignOf(DetachedClosure)) u8 = @ptrCast(closure);
+ gpa.free(base[0..contextEnd(closure.context_alignment, closure.context_len)]);
+ }
+
+ fn contextOffset(context_alignment: std.mem.Alignment) usize {
+ return context_alignment.forward(@sizeOf(DetachedClosure));
+ }
+
+ fn contextEnd(context_alignment: std.mem.Alignment, context_len: usize) usize {
+ return contextOffset(context_alignment) + context_len;
+ }
+
+ fn contextPointer(closure: *DetachedClosure) [*]u8 {
+ const base: [*]u8 = @ptrCast(closure);
+ return base + contextOffset(closure.context_alignment);
+ }
+};
+
+fn go(
+ userdata: ?*anyopaque,
+ context: []const u8,
+ context_alignment: std.mem.Alignment,
+ start: *const fn (context: *const anyopaque) void,
+) void {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ pool.mutex.lock();
+
+ const gpa = pool.allocator;
+ const n = DetachedClosure.contextEnd(context_alignment, context.len);
+ const closure: *DetachedClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, .of(DetachedClosure), n) catch {
+ pool.mutex.unlock();
+ start(context.ptr);
+ return;
+ }));
+ closure.* = .{
+ .pool = pool,
+ .func = start,
+ .context_alignment = context_alignment,
+ .context_len = context.len,
+ };
+ @memcpy(closure.contextPointer()[0..context.len], context);
+    pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+
+ pool.mutex.unlock();
+ pool.cond.signal();
+}
+
+fn await(
+ userdata: ?*anyopaque,
+ any_future: *std.Io.AnyFuture,
+ result: []u8,
+ result_alignment: std.mem.Alignment,
+) void {
+ _ = result_alignment;
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
+ closure.waitAndFree(pool.allocator, result);
+}
+
+fn cancel(
+ userdata: ?*anyopaque,
+ any_future: *Io.AnyFuture,
+ result: []u8,
+ result_alignment: std.mem.Alignment,
+) void {
+ _ = result_alignment;
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
+ switch (@atomicRmw(
+ std.Thread.Id,
+ &closure.cancel_tid,
+ .Xchg,
+ AsyncClosure.canceling_tid,
+ .acq_rel,
+ )) {
+ 0, AsyncClosure.canceling_tid => {},
+ else => |cancel_tid| switch (builtin.os.tag) {
+ .linux => _ = std.os.linux.tgkill(
+ std.os.linux.getpid(),
+ @bitCast(cancel_tid),
+ std.posix.SIG.IO,
+ ),
+ else => {},
+ },
+ }
+ closure.waitAndFree(pool.allocator, result);
+}
+
+fn cancelRequested(userdata: ?*anyopaque) bool {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
+ const closure = current_closure orelse return false;
+ return @atomicLoad(std.Thread.Id, &closure.cancel_tid, .acquire) == AsyncClosure.canceling_tid;
+}
+
+fn checkCancel(pool: *Pool) error{Canceled}!void {
+ if (cancelRequested(pool)) return error.Canceled;
+}
+
+fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
+ _ = userdata;
+ if (prev_state == .contended) {
+ std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+ }
+ while (@atomicRmw(
+ Io.Mutex.State,
+ &mutex.state,
+ .Xchg,
+ .contended,
+ .acquire,
+ ) != .unlocked) {
+ std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+ }
+}
+fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
+ _ = userdata;
+ _ = prev_state;
+ if (@atomicRmw(Io.Mutex.State, &mutex.state, .Xchg, .unlocked, .release) == .contended) {
+ std.Thread.Futex.wake(@ptrCast(&mutex.state), 1);
+ }
+}
+
+fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) Io.Cancelable!void {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ comptime assert(@TypeOf(cond.state) == u64);
+ const ints: *[2]std.atomic.Value(u32) = @ptrCast(&cond.state);
+ const cond_state = &ints[0];
+ const cond_epoch = &ints[1];
+ const one_waiter = 1;
+ const waiter_mask = 0xffff;
+ const one_signal = 1 << 16;
+ const signal_mask = 0xffff << 16;
+ // Observe the epoch, then check the state again to see if we should wake up.
+ // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
+ //
+ // - T1: s = LOAD(&state)
+ // - T2: UPDATE(&s, signal)
+ // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
+ // - T1: e = LOAD(&epoch) (was reordered after the state load)
+ // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
+ //
+ // Acquire barrier to ensure the epoch load happens before the state load.
+ var epoch = cond_epoch.load(.acquire);
+ var state = cond_state.fetchAdd(one_waiter, .monotonic);
+ assert(state & waiter_mask != waiter_mask);
+ state += one_waiter;
+
+ mutex.unlock(pool.io());
+ defer mutex.lock(pool.io()) catch @panic("TODO");
+
+ var futex_deadline = std.Thread.Futex.Deadline.init(null);
+
+ while (true) {
+ futex_deadline.wait(cond_epoch, epoch) catch |err| switch (err) {
+ error.Timeout => unreachable,
+ };
+
+ epoch = cond_epoch.load(.acquire);
+ state = cond_state.load(.monotonic);
+
+        // Try to wake up by consuming a signal and decrementing the waiter we added previously.
+ // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+ while (state & signal_mask != 0) {
+ const new_state = state - one_waiter - one_signal;
+ state = cond_state.cmpxchgWeak(state, new_state, .acquire, .monotonic) orelse return;
+ }
+ }
+}
+
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
+ comptime assert(@TypeOf(cond.state) == u64);
+ const ints: *[2]std.atomic.Value(u32) = @ptrCast(&cond.state);
+ const cond_state = &ints[0];
+ const cond_epoch = &ints[1];
+ const one_waiter = 1;
+ const waiter_mask = 0xffff;
+ const one_signal = 1 << 16;
+ const signal_mask = 0xffff << 16;
+ var state = cond_state.load(.monotonic);
+ while (true) {
+ const waiters = (state & waiter_mask) / one_waiter;
+ const signals = (state & signal_mask) / one_signal;
+
+ // Reserves which waiters to wake up by incrementing the signals count.
+ // Therefore, the signals count is always less than or equal to the waiters count.
+        // We don't need to Futex.wake if there's nothing to wake up, or if other wake() threads have already reserved wake-ups for the current waiters.
+ const wakeable = waiters - signals;
+ if (wakeable == 0) {
+ return;
+ }
+
+ const to_wake = switch (wake) {
+ .one => 1,
+ .all => wakeable,
+ };
+
+        // Reserve the number of waiters to wake by incrementing the signals count.
+        // Release barrier ensures code before the wake() happens before the signal is posted and consumed by the wait() threads.
+ const new_state = state + (one_signal * to_wake);
+ state = cond_state.cmpxchgWeak(state, new_state, .release, .monotonic) orelse {
+ // Wake up the waiting threads we reserved above by changing the epoch value.
+ // NOTE: a waiting thread could miss a wake up if *exactly* ((1<<32)-1) wake()s happen between it observing the epoch and sleeping on it.
+            // This is very unlikely because that exact number of Futex.wake() calls would have to occur while the waiting thread is preempted.
+ //
+ // Release barrier ensures the signal being added to the state happens before the epoch is changed.
+ // If not, the waiting thread could potentially deadlock from missing both the state and epoch change:
+ //
+ // - T2: UPDATE(&epoch, 1) (reordered before the state change)
+ // - T1: e = LOAD(&epoch)
+ // - T1: s = LOAD(&state)
+ // - T2: UPDATE(&state, signal) + FUTEX_WAKE(&epoch)
+ // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed both epoch change and state change)
+ _ = cond_epoch.fetchAdd(1, .release);
+ std.Thread.Futex.wake(cond_epoch, to_wake);
+ return;
+ };
+ }
+}
+
+fn createFile(
+ userdata: ?*anyopaque,
+ dir: Io.Dir,
+ sub_path: []const u8,
+ flags: Io.File.CreateFlags,
+) Io.File.OpenError!Io.File {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
+ const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
+ const fs_file = try fs_dir.createFile(sub_path, flags);
+ return .{ .handle = fs_file.handle };
+}
+
+fn openFile(
+ userdata: ?*anyopaque,
+ dir: Io.Dir,
+ sub_path: []const u8,
+ flags: Io.File.OpenFlags,
+) Io.File.OpenError!Io.File {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
+ const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
+ const fs_file = try fs_dir.openFile(sub_path, flags);
+ return .{ .handle = fs_file.handle };
+}
+
+fn closeFile(userdata: ?*anyopaque, file: Io.File) void {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
+ const fs_file: std.fs.File = .{ .handle = file.handle };
+ return fs_file.close();
+}
+
+fn pread(userdata: ?*anyopaque, file: Io.File, buffer: []u8, offset: std.posix.off_t) Io.File.PReadError!usize {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
+ const fs_file: std.fs.File = .{ .handle = file.handle };
+ return switch (offset) {
+ -1 => fs_file.read(buffer),
+ else => fs_file.pread(buffer, @bitCast(offset)),
+ };
+}
+
+fn pwrite(userdata: ?*anyopaque, file: Io.File, buffer: []const u8, offset: std.posix.off_t) Io.File.PWriteError!usize {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
+ const fs_file: std.fs.File = .{ .handle = file.handle };
+ return switch (offset) {
+ -1 => fs_file.write(buffer),
+ else => fs_file.pwrite(buffer, @bitCast(offset)),
+ };
+}
+
+fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
+ const timespec = try std.posix.clock_gettime(clockid);
+ return @enumFromInt(@as(i128, timespec.sec) * std.time.ns_per_s + timespec.nsec);
+}
+
+fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ const deadline_nanoseconds: i96 = switch (deadline) {
+ .duration => |duration| duration.nanoseconds,
+ .timestamp => |timestamp| @intFromEnum(timestamp),
+ };
+ var timespec: std.posix.timespec = .{
+ .sec = @intCast(@divFloor(deadline_nanoseconds, std.time.ns_per_s)),
+ .nsec = @intCast(@mod(deadline_nanoseconds, std.time.ns_per_s)),
+ };
+ while (true) {
+ try pool.checkCancel();
+ switch (std.os.linux.E.init(std.os.linux.clock_nanosleep(clockid, .{ .ABSTIME = switch (deadline) {
+ .duration => false,
+ .timestamp => true,
+        } }, &timespec, &timespec))) {
+ .SUCCESS => return,
+ .FAULT => unreachable,
+ .INTR => {},
+ .INVAL => return error.UnsupportedClock,
+ else => |err| return std.posix.unexpectedErrno(err),
+ }
+ }
+}
+
+fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
+ const pool: *Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
+
+ var reset_event: std.Thread.ResetEvent = .{};
+
+ for (futures, 0..) |future, i| {
+ const closure: *AsyncClosure = @ptrCast(@alignCast(future));
+ if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, &reset_event, .seq_cst) == AsyncClosure.done_reset_event) {
+ for (futures[0..i]) |cleanup_future| {
+ const cleanup_closure: *AsyncClosure = @ptrCast(@alignCast(cleanup_future));
+ if (@atomicRmw(?*std.Thread.ResetEvent, &cleanup_closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
+ cleanup_closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+ }
+ }
+ return i;
+ }
+ }
+
+ reset_event.wait();
+
+ var result: ?usize = null;
+ for (futures, 0..) |future, i| {
+ const closure: *AsyncClosure = @ptrCast(@alignCast(future));
+ if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
+ closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+ if (result == null) result = i; // In case multiple are ready, return first.
+ }
+ }
+ return result.?;
+}
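
A minimal usage sketch of the new pool (illustrative, not part of the diff; the `square` helper, the buffer size, and the `main` wiring are assumptions). Work is queued with `spawnWg` and drained with `waitAndWork`, the same pattern `std.Thread.Pool` supported before this split:

const std = @import("std");

fn square(results: []u64, i: usize) void {
    results[i] = i * i;
}

pub fn main() !void {
    var gpa_state: std.heap.GeneralPurposeAllocator(.{}) = .{};
    defer _ = gpa_state.deinit();

    // The pool requires a thread-safe allocator; GeneralPurposeAllocator is
    // thread-safe by default in multi-threaded builds.
    var pool: std.Io.ThreadPool = undefined;
    try pool.init(.{ .allocator = gpa_state.allocator() });
    defer pool.deinit();

    var results: [8]u64 = undefined;
    var wait_group: std.Thread.WaitGroup = .{};
    for (0..results.len) |i| {
        // Queues the call; on allocation failure it runs inline instead.
        pool.spawnWg(&wait_group, square, .{ &results, i });
    }
    // Participate in the queue until all queued work has finished.
    pool.waitAndWork(&wait_group);
}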
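The `AsyncClosure` code above packs header, context, and result into a single allocation, forwarding each offset to the required alignment. A worked sanity check of that offset arithmetic (a sketch, not part of the diff; the header size is a stand-in, since @sizeOf(AsyncClosure) is target-dependent):

const std = @import("std");

test "closure layout offsets" {
    const header_size: usize = 64; // stand-in for @sizeOf(AsyncClosure)
    const context_alignment: std.mem.Alignment = .@"16";
    const context_len: usize = 20;
    const result_alignment: std.mem.Alignment = .@"8";

    // Mirrors AsyncClosure.contextOffset and AsyncClosure.resultOffset.
    const context_offset = context_alignment.forward(header_size);
    const result_offset = result_alignment.forward(context_offset + context_len);

    try std.testing.expectEqual(@as(usize, 64), context_offset); // 64 is already 16-aligned
    try std.testing.expectEqual(@as(usize, 88), result_offset); // 84 rounds up to 88
}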
lib/std/Thread/Pool.zig
@@ -1,34 +1,27 @@
-const builtin = @import("builtin");
const std = @import("std");
-const Allocator = std.mem.Allocator;
-const assert = std.debug.assert;
-const WaitGroup = @import("WaitGroup.zig");
-const Io = std.Io;
+const builtin = @import("builtin");
const Pool = @This();
+const WaitGroup = @import("WaitGroup.zig");
-/// Must be a thread-safe allocator.
-allocator: std.mem.Allocator,
mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
is_running: bool = true,
-threads: std.ArrayListUnmanaged(std.Thread),
+allocator: std.mem.Allocator,
+threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
ids: if (builtin.single_threaded) struct {
inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
fn getIndex(_: @This(), _: std.Thread.Id) usize {
return 0;
}
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),
-stack_size: usize,
-threadlocal var current_closure: ?*AsyncClosure = null;
-
-pub const Runnable = struct {
+const Runnable = struct {
runFn: RunProto,
node: std.SinglyLinkedList.Node = .{},
};
-pub const RunProto = *const fn (*Runnable, id: ?usize) void;
+const RunProto = *const fn (*Runnable, id: ?usize) void;
pub const Options = struct {
allocator: std.mem.Allocator,
@@ -38,36 +31,48 @@ pub const Options = struct {
};
pub fn init(pool: *Pool, options: Options) !void {
- const gpa = options.allocator;
- const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
- const threads = try gpa.alloc(std.Thread, thread_count);
- errdefer gpa.free(threads);
+ const allocator = options.allocator;
pool.* = .{
- .allocator = gpa,
- .threads = .initBuffer(threads),
+ .allocator = allocator,
+ .threads = if (builtin.single_threaded) .{} else &.{},
.ids = .{},
- .stack_size = options.stack_size,
};
- if (builtin.single_threaded) return;
+ if (builtin.single_threaded) {
+ return;
+ }
+ const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
if (options.track_ids) {
- try pool.ids.ensureTotalCapacity(gpa, 1 + thread_count);
+ try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
}
+
+ // kill and join any threads we spawned and free memory on error.
+ pool.threads = try allocator.alloc(std.Thread, thread_count);
+ var spawned: usize = 0;
+ errdefer pool.join(spawned);
+
+ for (pool.threads) |*thread| {
+ thread.* = try std.Thread.spawn(.{
+ .stack_size = options.stack_size,
+ .allocator = allocator,
+ }, worker, .{pool});
+ spawned += 1;
+ }
}
pub fn deinit(pool: *Pool) void {
- const gpa = pool.allocator;
- pool.join();
- pool.threads.deinit(gpa);
- pool.ids.deinit(gpa);
+ pool.join(pool.threads.len); // kill and join all threads.
+ pool.ids.deinit(pool.allocator);
pool.* = undefined;
}
-fn join(pool: *Pool) void {
- if (builtin.single_threaded) return;
+fn join(pool: *Pool, spawned: usize) void {
+ if (builtin.single_threaded) {
+ return;
+ }
{
pool.mutex.lock();
@@ -80,7 +85,11 @@ fn join(pool: *Pool) void {
// wake up any sleeping threads (this can be done outside the mutex)
// then wait for all the threads we know are spawned to complete.
pool.cond.broadcast();
- for (pool.threads.items) |thread| thread.join();
+ for (pool.threads[0..spawned]) |thread| {
+ thread.join();
+ }
+
+ pool.allocator.free(pool.threads);
}
/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
@@ -108,38 +117,36 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args
const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
@call(.auto, func, closure.arguments);
closure.wait_group.finish();
- closure.pool.allocator.destroy(closure);
- }
- };
- pool.mutex.lock();
+ // The thread pool's allocator is protected by the mutex.
+ const mutex = &closure.pool.mutex;
+ mutex.lock();
+ defer mutex.unlock();
- const gpa = pool.allocator;
- const closure = gpa.create(Closure) catch {
- pool.mutex.unlock();
- @call(.auto, func, args);
- wait_group.finish();
- return;
- };
- closure.* = .{
- .arguments = args,
- .pool = pool,
- .wait_group = wait_group,
+ closure.pool.allocator.destroy(closure);
+ }
};
- pool.run_queue.prepend(&closure.runnable.node);
+ {
+ pool.mutex.lock();
- if (pool.threads.items.len < pool.threads.capacity) {
- pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
- .stack_size = pool.stack_size,
- .allocator = gpa,
- }, worker, .{pool}) catch t: {
- pool.threads.items.len -= 1;
- break :t undefined;
+ const closure = pool.allocator.create(Closure) catch {
+ pool.mutex.unlock();
+ @call(.auto, func, args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
};
+
+ pool.run_queue.prepend(&closure.runnable.node);
+ pool.mutex.unlock();
}
- pool.mutex.unlock();
+ // Notify waiting threads outside the lock to try and keep the critical section small.
pool.cond.signal();
}
@@ -172,43 +179,41 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar
const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
@call(.auto, func, .{id.?} ++ closure.arguments);
closure.wait_group.finish();
- closure.pool.allocator.destroy(closure);
- }
- };
- pool.mutex.lock();
+ // The thread pool's allocator is protected by the mutex.
+ const mutex = &closure.pool.mutex;
+ mutex.lock();
+ defer mutex.unlock();
- const gpa = pool.allocator;
- const closure = gpa.create(Closure) catch {
- const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
- pool.mutex.unlock();
- @call(.auto, func, .{id.?} ++ args);
- wait_group.finish();
- return;
- };
- closure.* = .{
- .arguments = args,
- .pool = pool,
- .wait_group = wait_group,
+ closure.pool.allocator.destroy(closure);
+ }
};
- pool.run_queue.prepend(&closure.runnable.node);
+ {
+ pool.mutex.lock();
- if (pool.threads.items.len < pool.threads.capacity) {
- pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
- .stack_size = pool.stack_size,
- .allocator = gpa,
- }, worker, .{pool}) catch t: {
- pool.threads.items.len -= 1;
- break :t undefined;
+ const closure = pool.allocator.create(Closure) catch {
+ const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
+ pool.mutex.unlock();
+ @call(.auto, func, .{id.?} ++ args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
};
+
+ pool.run_queue.prepend(&closure.runnable.node);
+ pool.mutex.unlock();
}
- pool.mutex.unlock();
+ // Notify waiting threads outside the lock to try and keep the critical section small.
pool.cond.signal();
}
-pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) void {
+pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
if (builtin.single_threaded) {
@call(.auto, func, args);
return;
@@ -223,36 +228,30 @@ pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) void {
fn runFn(runnable: *Runnable, _: ?usize) void {
const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
@call(.auto, func, closure.arguments);
- closure.pool.allocator.destroy(closure);
- }
- };
- pool.mutex.lock();
+ // The thread pool's allocator is protected by the mutex.
+ const mutex = &closure.pool.mutex;
+ mutex.lock();
+ defer mutex.unlock();
- const gpa = pool.allocator;
- const closure = gpa.create(Closure) catch {
- pool.mutex.unlock();
- @call(.auto, func, args);
- return;
- };
- closure.* = .{
- .arguments = args,
- .pool = pool,
+ closure.pool.allocator.destroy(closure);
+ }
};
- pool.run_queue.prepend(&closure.runnable.node);
+ {
+ pool.mutex.lock();
+ defer pool.mutex.unlock();
- if (pool.threads.items.len < pool.threads.capacity) {
- pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
- .stack_size = pool.stack_size,
- .allocator = gpa,
- }, worker, .{pool}) catch t: {
- pool.threads.items.len -= 1;
- break :t undefined;
+ const closure = try pool.allocator.create(Closure);
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
};
+
+ pool.run_queue.prepend(&closure.runnable.node);
}
- pool.mutex.unlock();
+ // Notify waiting threads outside the lock to try and keep the critical section small.
pool.cond.signal();
}
@@ -271,7 +270,7 @@ test spawn {
.allocator = std.testing.allocator,
});
defer pool.deinit();
- pool.spawn(TestFn.checkRun, .{&completed});
+ try pool.spawn(TestFn.checkRun, .{&completed});
}
try std.testing.expectEqual(true, completed);
@@ -323,530 +322,5 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
}
pub fn getIdCount(pool: *Pool) usize {
- return @intCast(1 + pool.threads.items.len);
-}
-
-pub fn io(pool: *Pool) Io {
- return .{
- .userdata = pool,
- .vtable = &.{
- .async = async,
- .await = await,
- .go = go,
- .cancel = cancel,
- .cancelRequested = cancelRequested,
- .select = select,
-
- .mutexLock = mutexLock,
- .mutexUnlock = mutexUnlock,
-
- .conditionWait = conditionWait,
- .conditionWake = conditionWake,
-
- .createFile = createFile,
- .openFile = openFile,
- .closeFile = closeFile,
- .pread = pread,
- .pwrite = pwrite,
-
- .now = now,
- .sleep = sleep,
- },
- };
-}
-
-const AsyncClosure = struct {
- func: *const fn (context: *anyopaque, result: *anyopaque) void,
- runnable: Runnable = .{ .runFn = runFn },
- reset_event: std.Thread.ResetEvent,
- select_condition: ?*std.Thread.ResetEvent,
- cancel_tid: std.Thread.Id,
- context_offset: usize,
- result_offset: usize,
-
- const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(@alignOf(std.Thread.ResetEvent));
-
- const canceling_tid: std.Thread.Id = switch (@typeInfo(std.Thread.Id)) {
- .int => |int_info| switch (int_info.signedness) {
- .signed => -1,
- .unsigned => std.math.maxInt(std.Thread.Id),
- },
- .pointer => @ptrFromInt(std.math.maxInt(usize)),
- else => @compileError("unsupported std.Thread.Id: " ++ @typeName(std.Thread.Id)),
- };
-
- fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void {
- const closure: *AsyncClosure = @alignCast(@fieldParentPtr("runnable", runnable));
- const tid = std.Thread.getCurrentId();
- if (@cmpxchgStrong(
- std.Thread.Id,
- &closure.cancel_tid,
- 0,
- tid,
- .acq_rel,
- .acquire,
- )) |cancel_tid| {
- assert(cancel_tid == canceling_tid);
- return;
- }
- current_closure = closure;
- closure.func(closure.contextPointer(), closure.resultPointer());
- current_closure = null;
- if (@cmpxchgStrong(
- std.Thread.Id,
- &closure.cancel_tid,
- tid,
- 0,
- .acq_rel,
- .acquire,
- )) |cancel_tid| assert(cancel_tid == canceling_tid);
-
- if (@atomicRmw(
- ?*std.Thread.ResetEvent,
- &closure.select_condition,
- .Xchg,
- done_reset_event,
- .release,
- )) |select_reset| {
- assert(select_reset != done_reset_event);
- select_reset.set();
- }
- closure.reset_event.set();
- }
-
- fn contextOffset(context_alignment: std.mem.Alignment) usize {
- return context_alignment.forward(@sizeOf(AsyncClosure));
- }
-
- fn resultOffset(
- context_alignment: std.mem.Alignment,
- context_len: usize,
- result_alignment: std.mem.Alignment,
- ) usize {
- return result_alignment.forward(contextOffset(context_alignment) + context_len);
- }
-
- fn resultPointer(closure: *AsyncClosure) [*]u8 {
- const base: [*]u8 = @ptrCast(closure);
- return base + closure.result_offset;
- }
-
- fn contextPointer(closure: *AsyncClosure) [*]u8 {
- const base: [*]u8 = @ptrCast(closure);
- return base + closure.context_offset;
- }
-
- fn waitAndFree(closure: *AsyncClosure, gpa: Allocator, result: []u8) void {
- closure.reset_event.wait();
- const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure);
- @memcpy(result, closure.resultPointer()[0..result.len]);
- gpa.free(base[0 .. closure.result_offset + result.len]);
- }
-};
-
-fn async(
- userdata: ?*anyopaque,
- result: []u8,
- result_alignment: std.mem.Alignment,
- context: []const u8,
- context_alignment: std.mem.Alignment,
- start: *const fn (context: *const anyopaque, result: *anyopaque) void,
-) ?*Io.AnyFuture {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- pool.mutex.lock();
-
- const gpa = pool.allocator;
- const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
- const result_offset = result_alignment.forward(context_offset + context.len);
- const n = result_offset + result.len;
- const closure: *AsyncClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, @alignOf(AsyncClosure), n) catch {
- pool.mutex.unlock();
- start(context.ptr, result.ptr);
- return null;
- }));
- closure.* = .{
- .func = start,
- .context_offset = context_offset,
- .result_offset = result_offset,
- .reset_event = .{},
- .cancel_tid = 0,
- .select_condition = null,
- };
- @memcpy(closure.contextPointer()[0..context.len], context);
- pool.run_queue.prepend(&closure.runnable.node);
-
- if (pool.threads.items.len < pool.threads.capacity) {
- pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
- .stack_size = pool.stack_size,
- .allocator = gpa,
- }, worker, .{pool}) catch t: {
- pool.threads.items.len -= 1;
- break :t undefined;
- };
- }
-
- pool.mutex.unlock();
- pool.cond.signal();
-
- return @ptrCast(closure);
-}
-
-const DetachedClosure = struct {
- pool: *Pool,
- func: *const fn (context: *anyopaque) void,
- run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } },
- context_alignment: std.mem.Alignment,
- context_len: usize,
-
- fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void {
- const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable);
- const closure: *DetachedClosure = @alignCast(@fieldParentPtr("run_node", run_node));
- closure.func(closure.contextPointer());
- const gpa = closure.pool.allocator;
- const base: [*]align(@alignOf(DetachedClosure)) u8 = @ptrCast(closure);
- gpa.free(base[0..contextEnd(closure.context_alignment, closure.context_len)]);
- }
-
- fn contextOffset(context_alignment: std.mem.Alignment) usize {
- return context_alignment.forward(@sizeOf(DetachedClosure));
- }
-
- fn contextEnd(context_alignment: std.mem.Alignment, context_len: usize) usize {
- return contextOffset(context_alignment) + context_len;
- }
-
- fn contextPointer(closure: *DetachedClosure) [*]u8 {
- const base: [*]u8 = @ptrCast(closure);
- return base + contextOffset(closure.context_alignment);
- }
-};
-
-fn go(
- userdata: ?*anyopaque,
- context: []const u8,
- context_alignment: std.mem.Alignment,
- start: *const fn (context: *const anyopaque) void,
-) void {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- pool.mutex.lock();
-
- const gpa = pool.allocator;
- const n = DetachedClosure.contextEnd(context_alignment, context.len);
- const closure: *DetachedClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, @alignOf(DetachedClosure), n) catch {
- pool.mutex.unlock();
- start(context.ptr);
- return;
- }));
- closure.* = .{
- .pool = pool,
- .func = start,
- .context_alignment = context_alignment,
- .context_len = context.len,
- };
- @memcpy(closure.contextPointer()[0..context.len], context);
- pool.run_queue.prepend(&closure.run_node);
-
- if (pool.threads.items.len < pool.threads.capacity) {
- pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
- .stack_size = pool.stack_size,
- .allocator = gpa,
- }, worker, .{pool}) catch t: {
- pool.threads.items.len -= 1;
- break :t undefined;
- };
- }
-
- pool.mutex.unlock();
- pool.cond.signal();
-}
-
-fn await(
- userdata: ?*anyopaque,
- any_future: *std.Io.AnyFuture,
- result: []u8,
- result_alignment: std.mem.Alignment,
-) void {
- _ = result_alignment;
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
- closure.waitAndFree(pool.allocator, result);
-}
-
-fn cancel(
- userdata: ?*anyopaque,
- any_future: *Io.AnyFuture,
- result: []u8,
- result_alignment: std.mem.Alignment,
-) void {
- _ = result_alignment;
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
- switch (@atomicRmw(
- std.Thread.Id,
- &closure.cancel_tid,
- .Xchg,
- AsyncClosure.canceling_tid,
- .acq_rel,
- )) {
- 0, AsyncClosure.canceling_tid => {},
- else => |cancel_tid| switch (builtin.os.tag) {
- .linux => _ = std.os.linux.tgkill(
- std.os.linux.getpid(),
- @bitCast(cancel_tid),
- std.posix.SIG.IO,
- ),
- else => {},
- },
- }
- closure.waitAndFree(pool.allocator, result);
-}
-
-fn cancelRequested(userdata: ?*anyopaque) bool {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- _ = pool;
- const closure = current_closure orelse return false;
- return @atomicLoad(std.Thread.Id, &closure.cancel_tid, .acquire) == AsyncClosure.canceling_tid;
-}
-
-fn checkCancel(pool: *Pool) error{Canceled}!void {
- if (cancelRequested(pool)) return error.Canceled;
-}
-
-fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
- _ = userdata;
- if (prev_state == .contended) {
- std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
- }
- while (@atomicRmw(
- Io.Mutex.State,
- &mutex.state,
- .Xchg,
- .contended,
- .acquire,
- ) != .unlocked) {
- std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
- }
-}
-fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
- _ = userdata;
- _ = prev_state;
- if (@atomicRmw(Io.Mutex.State, &mutex.state, .Xchg, .unlocked, .release) == .contended) {
- std.Thread.Futex.wake(@ptrCast(&mutex.state), 1);
- }
-}
-
-fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) Io.Cancelable!void {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- comptime assert(@TypeOf(cond.state) == u64);
- const ints: *[2]std.atomic.Value(u32) = @ptrCast(&cond.state);
- const cond_state = &ints[0];
- const cond_epoch = &ints[1];
- const one_waiter = 1;
- const waiter_mask = 0xffff;
- const one_signal = 1 << 16;
- const signal_mask = 0xffff << 16;
- // Observe the epoch, then check the state again to see if we should wake up.
- // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
- //
- // - T1: s = LOAD(&state)
- // - T2: UPDATE(&s, signal)
- // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
- // - T1: e = LOAD(&epoch) (was reordered after the state load)
- // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
- //
- // Acquire barrier to ensure the epoch load happens before the state load.
- var epoch = cond_epoch.load(.acquire);
- var state = cond_state.fetchAdd(one_waiter, .monotonic);
- assert(state & waiter_mask != waiter_mask);
- state += one_waiter;
-
- mutex.unlock(pool.io());
- defer mutex.lock(pool.io()) catch @panic("TODO");
-
- var futex_deadline = std.Thread.Futex.Deadline.init(null);
-
- while (true) {
- futex_deadline.wait(cond_epoch, epoch) catch |err| switch (err) {
- error.Timeout => unreachable,
- };
-
- epoch = cond_epoch.load(.acquire);
- state = cond_state.load(.monotonic);
-
- // Try to wake up by consuming a signal and decremented the waiter we added previously.
- // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
- while (state & signal_mask != 0) {
- const new_state = state - one_waiter - one_signal;
- state = cond_state.cmpxchgWeak(state, new_state, .acquire, .monotonic) orelse return;
- }
- }
-}
-
-fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- _ = pool;
- comptime assert(@TypeOf(cond.state) == u64);
- const ints: *[2]std.atomic.Value(u32) = @ptrCast(&cond.state);
- const cond_state = &ints[0];
- const cond_epoch = &ints[1];
- const one_waiter = 1;
- const waiter_mask = 0xffff;
- const one_signal = 1 << 16;
- const signal_mask = 0xffff << 16;
- var state = cond_state.load(.monotonic);
- while (true) {
- const waiters = (state & waiter_mask) / one_waiter;
- const signals = (state & signal_mask) / one_signal;
-
- // Reserves which waiters to wake up by incrementing the signals count.
- // Therefore, the signals count is always less than or equal to the waiters count.
- // We don't need to Futex.wake if there's nothing to wake up or if other wake() threads have reserved to wake up the current waiters.
- const wakeable = waiters - signals;
- if (wakeable == 0) {
- return;
- }
-
- const to_wake = switch (wake) {
- .one => 1,
- .all => wakeable,
- };
-
- // Reserve the amount of waiters to wake by incrementing the signals count.
- // Release barrier ensures code before the wake() happens before the signal it posted and consumed by the wait() threads.
- const new_state = state + (one_signal * to_wake);
- state = cond_state.cmpxchgWeak(state, new_state, .release, .monotonic) orelse {
- // Wake up the waiting threads we reserved above by changing the epoch value.
- // NOTE: a waiting thread could miss a wake up if *exactly* ((1<<32)-1) wake()s happen between it observing the epoch and sleeping on it.
- // This is very unlikely due to how many precise amount of Futex.wake() calls that would be between the waiting thread's potential preemption.
- //
- // Release barrier ensures the signal being added to the state happens before the epoch is changed.
- // If not, the waiting thread could potentially deadlock from missing both the state and epoch change:
- //
- // - T2: UPDATE(&epoch, 1) (reordered before the state change)
- // - T1: e = LOAD(&epoch)
- // - T1: s = LOAD(&state)
- // - T2: UPDATE(&state, signal) + FUTEX_WAKE(&epoch)
- // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed both epoch change and state change)
- _ = cond_epoch.fetchAdd(1, .release);
- std.Thread.Futex.wake(cond_epoch, to_wake);
- return;
- };
- }
-}
-
-fn createFile(
- userdata: ?*anyopaque,
- dir: Io.Dir,
- sub_path: []const u8,
- flags: Io.File.CreateFlags,
-) Io.File.OpenError!Io.File {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- try pool.checkCancel();
- const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
- const fs_file = try fs_dir.createFile(sub_path, flags);
- return .{ .handle = fs_file.handle };
-}
-
-fn openFile(
- userdata: ?*anyopaque,
- dir: Io.Dir,
- sub_path: []const u8,
- flags: Io.File.OpenFlags,
-) Io.File.OpenError!Io.File {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- try pool.checkCancel();
- const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
- const fs_file = try fs_dir.openFile(sub_path, flags);
- return .{ .handle = fs_file.handle };
-}
-
-fn closeFile(userdata: ?*anyopaque, file: Io.File) void {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- _ = pool;
- const fs_file: std.fs.File = .{ .handle = file.handle };
- return fs_file.close();
-}
-
-fn pread(userdata: ?*anyopaque, file: Io.File, buffer: []u8, offset: std.posix.off_t) Io.File.PReadError!usize {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- try pool.checkCancel();
- const fs_file: std.fs.File = .{ .handle = file.handle };
- return switch (offset) {
- -1 => fs_file.read(buffer),
- else => fs_file.pread(buffer, @bitCast(offset)),
- };
-}
-
-fn pwrite(userdata: ?*anyopaque, file: Io.File, buffer: []const u8, offset: std.posix.off_t) Io.File.PWriteError!usize {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- try pool.checkCancel();
- const fs_file: std.fs.File = .{ .handle = file.handle };
- return switch (offset) {
- -1 => fs_file.write(buffer),
- else => fs_file.pwrite(buffer, @bitCast(offset)),
- };
-}
-
-fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- try pool.checkCancel();
- const timespec = try std.posix.clock_gettime(clockid);
- return @enumFromInt(@as(i128, timespec.sec) * std.time.ns_per_s + timespec.nsec);
-}
-
-fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- const deadline_nanoseconds: i96 = switch (deadline) {
- .duration => |duration| duration.nanoseconds,
- .timestamp => |timestamp| @intFromEnum(timestamp),
- };
- var timespec: std.posix.timespec = .{
- .sec = @intCast(@divFloor(deadline_nanoseconds, std.time.ns_per_s)),
- .nsec = @intCast(@mod(deadline_nanoseconds, std.time.ns_per_s)),
- };
- while (true) {
- try pool.checkCancel();
- switch (std.os.linux.E.init(std.os.linux.clock_nanosleep(clockid, .{ .ABSTIME = switch (deadline) {
- .duration => false,
- .timestamp => true,
-        } }, &timespec, &timespec))) {
- .SUCCESS => return,
- .FAULT => unreachable,
- .INTR => {},
- .INVAL => return error.UnsupportedClock,
- else => |err| return std.posix.unexpectedErrno(err),
- }
- }
-}
-
-fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
- const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
- _ = pool;
-
- var reset_event: std.Thread.ResetEvent = .{};
-
- for (futures, 0..) |future, i| {
- const closure: *AsyncClosure = @ptrCast(@alignCast(future));
- if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, &reset_event, .seq_cst) == AsyncClosure.done_reset_event) {
- for (futures[0..i]) |cleanup_future| {
- const cleanup_closure: *AsyncClosure = @ptrCast(@alignCast(cleanup_future));
- if (@atomicRmw(?*std.Thread.ResetEvent, &cleanup_closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
- cleanup_closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
- }
- }
- return i;
- }
- }
-
- reset_event.wait();
-
- var result: ?usize = null;
- for (futures, 0..) |future, i| {
- const closure: *AsyncClosure = @ptrCast(@alignCast(future));
- if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
- closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
- if (result == null) result = i; // In case multiple are ready, return first.
- }
- }
- return result.?;
+ return @intCast(1 + pool.threads.len);
}
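
For callers of `std.Thread.Pool`, the net effect of this hunk is that `init` once again spawns every worker eagerly (so thread-spawn errors surface there) and `spawn` returns an error union. A sketch mirroring the commit's own `test spawn` (not part of the diff; the `n_jobs` value is arbitrary):

const std = @import("std");

test "fallible spawn" {
    var completed = false;
    {
        var pool: std.Thread.Pool = undefined;
        try pool.init(.{ .allocator = std.testing.allocator, .n_jobs = 2 });
        defer pool.deinit();
        // spawn is !void again: OutOfMemory reaches the caller instead of
        // being swallowed by an inline-execution fallback.
        try pool.spawn(struct {
            fn run(flag: *bool) void {
                flag.* = true;
            }
        }.run, .{&completed});
    }
    try std.testing.expect(completed);
}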
lib/std/Io.zig
@@ -558,6 +558,7 @@ test {
const Io = @This();
pub const EventLoop = @import("Io/EventLoop.zig");
+pub const ThreadPool = @import("Io/ThreadPool.zig");
userdata: ?*anyopaque,
vtable: *const VTable,
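
The one-line Io.zig hunk is what makes the new implementation reachable as `std.Io.ThreadPool`. A short sketch of wiring it up (not part of the diff; the `expectEqual` check merely confirms that `userdata` points back at the pool, as `io()` sets it above):

const std = @import("std");

test "thread pool as an Io implementation" {
    var pool: std.Io.ThreadPool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator });
    defer pool.deinit();

    // io() wraps the pool in std.Io's vtable; code written against std.Io
    // is meant to be drivable by this pool or by Io.EventLoop alike.
    const io = pool.io();
    try std.testing.expectEqual(@as(?*anyopaque, &pool), io.userdata);
}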