const std = @import("std");
const builtin = @import("builtin");
const Pool = @This();
const WaitGroup = @import("WaitGroup.zig");

mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
is_running: bool = true,
allocator: std.mem.Allocator,
threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
ids: if (builtin.single_threaded) struct {
    inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
    fn getIndex(_: @This(), _: std.Thread.Id) usize {
        return 0;
    }
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),

const Runnable = struct {
    runFn: RunProto,
    node: std.SinglyLinkedList.Node = .{},
};

const RunProto = *const fn (*Runnable, id: ?usize) void;

pub const Options = struct {
    allocator: std.mem.Allocator,
    n_jobs: ?usize = null,
    track_ids: bool = false,
    stack_size: usize = std.Thread.SpawnConfig.default_stack_size,
};

pub fn init(pool: *Pool, options: Options) !void {
    const allocator = options.allocator;

    pool.* = .{
        .allocator = allocator,
        .threads = if (builtin.single_threaded) .{} else &.{},
        .ids = .{},
    };

    if (builtin.single_threaded) {
        return;
    }

    const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
    if (options.track_ids) {
        try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
        pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
    }

    // kill and join any threads we spawned and free memory on error.
    pool.threads = try allocator.alloc(std.Thread, thread_count);
    var spawned: usize = 0;
    errdefer pool.join(spawned);

    for (pool.threads) |*thread| {
        thread.* = try std.Thread.spawn(.{
            .stack_size = options.stack_size,
            .allocator = allocator,
        }, worker, .{pool});
        spawned += 1;
    }
}

pub fn deinit(pool: *Pool) void {
    pool.join(pool.threads.len); // kill and join all threads.
    pool.ids.deinit(pool.allocator);
    pool.* = undefined;
}

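/// Signals shutdown, wakes any sleeping workers, joins the first `spawned`
/// threads, and frees the thread slice. A no-op in single-threaded builds.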
fn join(pool: *Pool, spawned: usize) void {
    if (builtin.single_threaded) {
        return;
    }

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        // ensure future worker threads exit the dequeue loop
        pool.is_running = false;
    }

    // wake up any sleeping threads (this can be done outside the mutex)
    // then wait for all the threads we know are spawned to complete.
    pool.cond.broadcast();
    for (pool.threads[0..spawned]) |thread| {
        thread.join();
    }

    pool.allocator.free(pool.threads);
}

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            pool.mutex.unlock();
            @call(.auto, func, args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.runnable.node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}
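
// A minimal usage sketch for `spawnWg`: each queued call is paired with a
// `WaitGroup`, and `waitAndWork` lets the calling thread help drain the queue
// until the group completes. The atomic counter is only for illustration.
test spawnWg {
    const TestFn = struct {
        fn increment(counter: *std.atomic.Value(usize)) void {
            _ = counter.fetchAdd(1, .monotonic);
        }
    };

    var pool: Pool = undefined;
    try pool.init(.{ .allocator = std.testing.allocator });
    defer pool.deinit();

    var wait_group: WaitGroup = .{};
    var counter = std.atomic.Value(usize).init(0);
    for (0..4) |_| pool.spawnWg(&wait_group, TestFn.increment, .{&counter});
    pool.waitAndWork(&wait_group);

    try std.testing.expectEqual(@as(usize, 4), counter.load(.monotonic));
}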

/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
/// `WaitGroup.finish` after it returns.
///
/// The first argument passed to `func` is a dense `usize` thread id, the rest
/// of the arguments are passed from `args`. Requires the pool to have been
/// initialized with `.track_ids = true`.
///
/// In the case that queuing the function call fails to allocate memory, or the
/// target is single-threaded, the function is called directly.
pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
    wait_group.start();

    if (builtin.single_threaded) {
        @call(.auto, func, .{0} ++ args);
        wait_group.finish();
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },
        wait_group: *WaitGroup,

        fn runFn(runnable: *Runnable, id: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, .{id.?} ++ closure.arguments);
            closure.wait_group.finish();

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();

        const closure = pool.allocator.create(Closure) catch {
            const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            @call(.auto, func, .{id.?} ++ args);
            wait_group.finish();
            return;
        };
        closure.* = .{
            .arguments = args,
            .pool = pool,
            .wait_group = wait_group,
        };

        pool.run_queue.prepend(&closure.runnable.node);
        pool.mutex.unlock();
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}
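
// A minimal usage sketch for `spawnWgId`: with `.track_ids = true`, every
// callback receives a dense thread id in the range `[0, pool.getIdCount())`.
// The atomic flag is only for illustration.
test spawnWgId {
    const TestFn = struct {
        fn checkId(id: usize, max_id: *const usize, ok: *std.atomic.Value(bool)) void {
            if (id >= max_id.*) ok.store(false, .monotonic);
        }
    };

    var pool: Pool = undefined;
    try pool.init(.{
        .allocator = std.testing.allocator,
        .track_ids = true,
    });
    defer pool.deinit();

    var wait_group: WaitGroup = .{};
    const max_id: usize = pool.getIdCount();
    var ok = std.atomic.Value(bool).init(true);
    for (0..4) |_| pool.spawnWgId(&wait_group, TestFn.checkId, .{ &max_id, &ok });
    pool.waitAndWork(&wait_group);

    try std.testing.expect(ok.load(.monotonic));
}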

pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
    if (builtin.single_threaded) {
        @call(.auto, func, args);
        return;
    }

    const Args = @TypeOf(args);
    const Closure = struct {
        arguments: Args,
        pool: *Pool,
        runnable: Runnable = .{ .runFn = runFn },

        fn runFn(runnable: *Runnable, _: ?usize) void {
            const closure: *@This() = @alignCast(@fieldParentPtr("runnable", runnable));
            @call(.auto, func, closure.arguments);

            // The thread pool's allocator is protected by the mutex.
            const mutex = &closure.pool.mutex;
            mutex.lock();
            defer mutex.unlock();

            closure.pool.allocator.destroy(closure);
        }
    };

    {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        const closure = try pool.allocator.create(Closure);
        closure.* = .{
            .arguments = args,
            .pool = pool,
        };

        pool.run_queue.prepend(&closure.runnable.node);
    }

    // Notify waiting threads outside the lock to try and keep the critical section small.
    pool.cond.signal();
}

test spawn {
    const TestFn = struct {
        fn checkRun(completed: *bool) void {
            completed.* = true;
        }
    };

    var completed: bool = false;

    {
        var pool: Pool = undefined;
        try pool.init(.{
            .allocator = std.testing.allocator,
        });
        defer pool.deinit();
        try pool.spawn(TestFn.checkRun, .{&completed});
    }

    try std.testing.expectEqual(true, completed);
}

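/// Per-thread worker loop: registers a dense id when id tracking is enabled,
/// then repeatedly drains the run queue and sleeps on the condition variable
/// until more work arrives or the pool shuts down.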
fn worker(pool: *Pool) void {
    pool.mutex.lock();
    defer pool.mutex.unlock();

    const id: ?usize = if (pool.ids.count() > 0) @intCast(pool.ids.count()) else null;
    if (id) |_| pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});

    while (true) {
        while (pool.run_queue.popFirst()) |run_node| {
            // Temporarily unlock the mutex in order to execute the run_node
            pool.mutex.unlock();
            defer pool.mutex.lock();

            const runnable: *Runnable = @fieldParentPtr("node", run_node);
            runnable.runFn(runnable, id);
        }

        // Stop executing instead of waiting if the thread pool is no longer running.
        if (pool.is_running) {
            pool.cond.wait(&pool.mutex);
        } else {
            break;
        }
    }
}

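/// Blocks until `wait_group` completes. While waiting, the calling thread
/// helps execute queued work, so progress is made even when all workers are
/// busy or the queue is backed up.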
pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
    var id: ?usize = null;

    while (!wait_group.isDone()) {
        pool.mutex.lock();
        if (pool.run_queue.popFirst()) |run_node| {
            id = id orelse pool.ids.getIndex(std.Thread.getCurrentId());
            pool.mutex.unlock();
            const runnable: *Runnable = @fieldParentPtr("node", run_node);
            runnable.runFn(runnable, id);
            continue;
        }

        pool.mutex.unlock();
        wait_group.wait();
        return;
    }
}

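/// Returns the number of dense thread ids the pool can hand out: one for the
/// thread that called `init` plus one per worker thread. Useful for sizing
/// per-id state consumed by `spawnWgId` callbacks.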
pub fn getIdCount(pool: *Pool) usize {
    return @intCast(1 + pool.threads.len);
}