Commit cb9f9bf58d

Andrew Kelley <andrew@ziglang.org>
2025-03-25 02:49:03
make thread pool satisfy async/await interface
1 parent 21b7316
Changed files (1)
lib
std
Thread
lib/std/Thread/Pool.zig
@@ -1,7 +1,8 @@
-const std = @import("std");
 const builtin = @import("builtin");
-const Pool = @This();
+const std = @import("std");
+const assert = std.debug.assert;
 const WaitGroup = @import("WaitGroup.zig");
+const Pool = @This();
 
 mutex: std.Thread.Mutex = .{},
 cond: std.Thread.Condition = .{},
@@ -307,3 +308,60 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
 pub fn getIdCount(pool: *Pool) usize {
     return @intCast(1 + pool.threads.len);
 }
+
+const AsyncClosure = struct {
+    func: *const fn (context: ?*anyopaque, result: *anyopaque) void,
+    context: ?*anyopaque,
+    run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } },
+    reset_event: std.Thread.ResetEvent,
+
+    fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void {
+        const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable);
+        const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
+        closure.func(closure.context, closure.resultPointer());
+        closure.reset_event.set();
+    }
+
+    fn resultPointer(closure: *@This()) [*]u8 {
+        const base: [*]u8 = @ptrCast(closure);
+        return base + @sizeOf(@This());
+    }
+};
+
+pub fn @"async"(
+    userdata: ?*anyopaque,
+    eager_result: []u8,
+    context: ?*anyopaque,
+    start: *const fn (context: ?*anyopaque, result: *anyopaque) void,
+) ?*std.Io.AnyFuture {
+    const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+    pool.mutex.lock();
+
+    const gpa = pool.allocator;
+    const n = @sizeOf(AsyncClosure) + eager_result.len;
+    const closure: *AsyncClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, @alignOf(AsyncClosure), n) catch {
+        pool.mutex.unlock();
+        start(context, eager_result.ptr);
+        return null;
+    }));
+    closure.* = .{
+        .func = start,
+        .context = context,
+        .reset_event = .{},
+    };
+    pool.run_queue.prepend(&closure.run_node);
+    pool.mutex.unlock();
+
+    pool.cond.signal();
+
+    return @ptrCast(closure);
+}
+
+pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void {
+    const thread_pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+    const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
+    closure.reset_event.wait();
+    const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure);
+    @memcpy(result, (base + @sizeOf(AsyncClosure))[0..result.len]);
+    thread_pool.allocator.free(base[0 .. @sizeOf(AsyncClosure) + result.len]);
+}