const std = @import("std");
const builtin = @import("builtin");
const Pool = @This();
const WaitGroup = @import("WaitGroup.zig");

mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
is_running: bool = true,
allocator: std.mem.Allocator,
threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
ids: if (builtin.single_threaded) struct {
    inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
    fn getIndex(_: @This(), _: std.Thread.Id) usize {
        return 0;
    }
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),

const Runnable = struct {
    runFn: RunProto,
    node: std.SinglyLinkedList.Node = .{},
};

const RunProto = *const fn (*Runnable, id: ?usize) void;

pub const Options = struct {
    allocator: std.mem.Allocator,
    n_jobs: ?usize = null,
    track_ids: bool = false,
    stack_size: usize = std.Thread.SpawnConfig.default_stack_size,
};

pub fn init(pool: *Pool, options: Options) !void {
    const allocator = options.allocator;

    pool.* = .{
        .allocator = allocator,
        .threads = if (builtin.single_threaded) .{} else &.{},
        .ids = .{},
    };

    if (builtin.single_threaded) {
        return;
    }

    const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
    if (options.track_ids) {
        try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
        pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
    }

    // kill and join any threads we spawned and free memory on error.
    pool.threads = try allocator.alloc(std.Thread, thread_count);
    var spawned: usize = 0;
    errdefer pool.join(spawned);

    for (pool.threads) |*thread| {
        thread.* = try std.Thread.spawn(.{
            .stack_size = options.stack_size,
            .allocator = allocator,
        }, worker, .{pool});
        spawned += 1;
    }
}

pub fn deinit(pool: *Pool) void {
    pool.join(pool.threads.len); // kill and join all threads.
    pool.ids.deinit(pool.allocator);
    pool.* = undefined;
}

fn join(pool: *Pool, spawned: usize) void {
    if (builtin.single_threaded) {
        return;
    }

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        // ensure future worker threads exit the dequeue loop
        pool.is_running = false;
    }

    // wake up any sleeping threads (this can be done outside the mutex)
    // then wait for all the threads we know are spawned to complete.
    pool.cond.broadcast();
    for (pool.threads[0..spawned]) |thread| {
        thread.join();
    }

    pool.allocator.free(pool.threads);
}

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            pool.mutex.unlock();
            @call(.auto, func, args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.runnable.node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}
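
// A minimal usage sketch for `spawnWg` paired with `waitAndWork`: the wait
// group tracks outstanding tasks, and the caller blocks (helping to run queued
// work) until every task has called `WaitGroup.finish`. The `increment`
// helper and `value` counter are illustrative names, not part of the pool API.
test spawnWg {
    const TestFn = struct {
        fn increment(value: *std.atomic.Value(usize)) void {
            _ = value.fetchAdd(1, .monotonic);
        }
    };

    var value = std.atomic.Value(usize).init(0);
    var wait_group: WaitGroup = .{};

    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator });
    defer pool.deinit();

    pool.spawnWg(&wait_group, TestFn.increment, .{&value});
    pool.spawnWg(&wait_group, TestFn.increment, .{&value});
    pool.waitAndWork(&wait_group);

    try std.testing.expectEqual(@as(usize, 2), value.load(.monotonic));
}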

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// The first argument passed to `func` is a dense `usize` thread id, the rest
/// of the arguments are passed from `args`. Requires the pool to have been
/// initialized with `.track_ids = true`.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, .{0} ++ args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, id: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, .{id.?} ++ closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            @call(.auto, func, .{id.?} ++ args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.runnable.node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}
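
// A minimal usage sketch for `spawnWgId`, assuming `.track_ids = true`: each
// task receives a dense id in `[0, getIdCount())` as its first argument, and
// each id belongs to exactly one thread, so per-id slots need no atomics. The
// `tally` helper and `counts` slice are illustrative names.
test spawnWgId {
    const TestFn = struct {
        fn tally(id: usize, counts: []usize) void {
            counts[id] += 1;
        }
    };

    var pool: Pool = undefined;
    try pool.init(.{
        .allocator = std.testing.allocator,
        .track_ids = true,
    });
    defer pool.deinit();

    const counts = try std.testing.allocator.alloc(usize, pool.getIdCount());
    defer std.testing.allocator.free(counts);
    @memset(counts, 0);

    var wait_group: WaitGroup = .{};
    for (0..10) |_| pool.spawnWgId(&wait_group, TestFn.tally, .{counts});
    pool.waitAndWork(&wait_group);

    var total: usize = 0;
    for (counts) |count| total += count;
    try std.testing.expectEqual(@as(usize, 10), total);
}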

/// Runs `func` in the thread pool. Unlike `spawnWg`, completion is not
/// tracked; queuing may fail with an allocation error. On single-threaded
/// targets the function is called directly.
pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
    if (builtin.single_threaded) {
        @call(.auto, func, args);
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, closure.arguments);

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        const closure = try pool.allocator.create(Closure);
        closure.* = .{
            .arguments = args,
            .pool = pool,
        };

        pool.run_queue.prepend(&closure.runnable.node);
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

test spawn {
    const TestFn = struct {
        fn checkRun(completed: *bool) void {
            completed.* = true;
        }
    };

    var completed: bool = false;

    {
        var pool: Pool = undefined;
        try pool.init(.{
            .allocator = std.testing.allocator,
        });
        defer pool.deinit();
        try pool.spawn(TestFn.checkRun, .{&completed});
    }

    try std.testing.expectEqual(true, completed);
}

fn worker(pool: *Pool) void {
    pool.mutex.lock();
    defer pool.mutex.unlock();

    // When id tracking is enabled, the spawning thread was registered in
    // `init`, so `count() > 0`; this worker claims the next dense index.
    const id: ?usize = if (pool.ids.count() > 0) @intCast(pool.ids.count()) else null;
    if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});

    while (true) {
        while (pool.run_queue.popFirst()) |run_node| {
            // Temporarily unlock the mutex in order to execute the run_node
            pool.mutex.unlock();
            defer pool.mutex.lock();

            const runnable: *Runnable = @fieldParentPtr("node", run_node);
            runnable.runFn(runnable, id);
        }

        // Stop executing instead of waiting if the thread pool is no longer running.
        if (pool.is_running) {
            pool.cond.wait(&pool.mutex);
        } else {
            break;
        }
    }
}

/// Blocks until `wait_group` completes, running queued tasks on the calling
/// thread while waiting.
pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
    var id: ?usize = null;

    while (!wait_group.isDone()) {
        pool.mutex.lock();
        if (pool.run_queue.popFirst()) |run_node| {
            id = id orelse pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            const runnable: *Runnable = @fieldParentPtr("node", run_node);
            runnable.runFn(runnable, id);
            continue;
        }

        pool.mutex.unlock();
        wait_group.wait();
        return;
    }
}

/// Returns the number of dense thread ids: one for the spawning thread plus
/// one per worker thread. Ids handed to `spawnWgId` callbacks are in the
/// range `[0, getIdCount())`.
pub fn getIdCount(pool: *Pool) usize {
    return @intCast(1 + pool.threads.len);
}
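
// A small sketch of the dense id range, assuming a pool initialized with
// `.track_ids = true`. On single-threaded targets no workers are spawned, so
// only the spawning thread's id exists.
test getIdCount {
    var pool: Pool = undefined;
    try pool.init(.{
        .allocator = std.testing.allocator,
        .n_jobs = 2,
        .track_ids = true,
    });
    defer pool.deinit();

    const expected: usize = if (builtin.single_threaded) 1 else 3;
    try std.testing.expectEqual(expected, pool.getIdCount());
}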