Commit 31ed2d6715

Andrew Kelley <andrew@ziglang.org>
2025-03-28 04:53:14
fix context passing in threaded Io impl
1 parent f1dd06b
Changed files (3)
lib
lib/std/Io/EventLoop.zig
@@ -55,6 +55,17 @@ const Fiber = struct {
     }
 };
 
+pub fn io(el: *EventLoop) Io {
+    return .{
+        .userdata = el,
+        .vtable = &.{
+            .@"async" = @"async",
+            .@"await" = @"await",
+        },
+    };
+}
+
+
 pub fn init(el: *EventLoop, gpa: Allocator) error{OutOfMemory}!void {
     const threads_bytes = ((std.Thread.getCpuCount() catch 1) -| 1) * @sizeOf(Thread);
     const idle_context_offset = std.mem.alignForward(usize, threads_bytes, @alignOf(Context));
lib/std/Thread/Pool.zig
@@ -309,46 +309,80 @@ pub fn getIdCount(pool: *Pool) usize {
     return @intCast(1 + pool.threads.len);
 }
 
+pub fn io(pool: *Pool) std.Io {
+    return .{
+        .userdata = pool,
+        .vtable = &.{
+            .@"async" = @"async",
+            .@"await" = @"await",
+        },
+    };
+}
+
 const AsyncClosure = struct {
-    func: *const fn (context: ?*anyopaque, result: *anyopaque) void,
-    context: ?*anyopaque,
+    func: *const fn (context: *anyopaque, result: *anyopaque) void,
     run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } },
     reset_event: std.Thread.ResetEvent,
+    context_offset: usize,
+    result_offset: usize,
 
     fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void {
         const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable);
-        const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
-        closure.func(closure.context, closure.resultPointer());
+        const closure: *AsyncClosure = @alignCast(@fieldParentPtr("run_node", run_node));
+        closure.func(closure.contextPointer(), closure.resultPointer());
         closure.reset_event.set();
     }
 
-    fn resultPointer(closure: *@This()) [*]u8 {
+    fn contextOffset(context_alignment: std.mem.Alignment) usize {
+        return context_alignment.forward(@sizeOf(AsyncClosure));
+    }
+
+    fn resultOffset(
+        context_alignment: std.mem.Alignment,
+        context_len: usize,
+        result_alignment: std.mem.Alignment,
+    ) usize {
+        return result_alignment.forward(contextOffset(context_alignment) + context_len);
+    }
+
+    fn resultPointer(closure: *AsyncClosure) [*]u8 {
+        const base: [*]u8 = @ptrCast(closure);
+        return base + closure.result_offset;
+    }
+
+    fn contextPointer(closure: *AsyncClosure) [*]u8 {
         const base: [*]u8 = @ptrCast(closure);
-        return base + @sizeOf(@This());
+        return base + closure.context_offset;
     }
 };
 
 pub fn @"async"(
     userdata: ?*anyopaque,
-    eager_result: []u8,
-    context: ?*anyopaque,
-    start: *const fn (context: ?*anyopaque, result: *anyopaque) void,
+    result: []u8,
+    result_alignment: std.mem.Alignment,
+    context: []const u8,
+    context_alignment: std.mem.Alignment,
+    start: *const fn (context: *const anyopaque, result: *anyopaque) void,
 ) ?*std.Io.AnyFuture {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     pool.mutex.lock();
 
     const gpa = pool.allocator;
-    const n = @sizeOf(AsyncClosure) + eager_result.len;
+    const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
+    const result_offset = result_alignment.forward(context_offset + context.len);
+    const n = result_offset + result.len;
     const closure: *AsyncClosure = @alignCast(@ptrCast(gpa.alignedAlloc(u8, @alignOf(AsyncClosure), n) catch {
         pool.mutex.unlock();
-        start(context, eager_result.ptr);
+        start(context.ptr, result.ptr);
         return null;
     }));
     closure.* = .{
         .func = start,
-        .context = context,
+        .context_offset = context_offset,
+        .result_offset = result_offset,
         .reset_event = .{},
     };
+    @memcpy(closure.contextPointer()[0..context.len], context);
     pool.run_queue.prepend(&closure.run_node);
     pool.mutex.unlock();
 
lib/std/Io.zig
@@ -570,10 +570,12 @@ pub const VTable = struct {
         /// The pointer of this slice is an "eager" result value.
         /// The length is the size in bytes of the result type.
         /// This pointer's lifetime expires directly after the call to this function.
-        eager_result: []u8,
-        /// Passed to `start`.
-        context: ?*anyopaque,
-        start: *const fn (context: ?*anyopaque, result: *anyopaque) void,
+        result: []u8,
+        result_alignment: std.mem.Alignment,
+        /// Copied and then passed to `start`.
+        context: []const u8,
+        context_alignment: std.mem.Alignment,
+        start: *const fn (context: *const anyopaque, result: *anyopaque) void,
     ) ?*AnyFuture,
 
     /// This function is only called when `async` returns a non-null value.
@@ -611,17 +613,23 @@ pub fn Future(Result: type) type {
 /// }
 /// ```
 /// where `Result` is any type.
-pub fn async(io: Io, s: anytype) Future(@typeInfo(@TypeOf(@TypeOf(s).start)).@"fn".return_type.?) {
-    const S = @TypeOf(s);
+pub fn async(io: Io, S: type, s: S) Future(@typeInfo(@TypeOf(S.start)).@"fn".return_type.?) {
     const Result = @typeInfo(@TypeOf(S.start)).@"fn".return_type.?;
     const TypeErased = struct {
-        fn start(context: ?*anyopaque, result: *anyopaque) void {
+        fn start(context: *const anyopaque, result: *anyopaque) void {
             const context_casted: *const S = @alignCast(@ptrCast(context));
             const result_casted: *Result = @ptrCast(@alignCast(result));
             result_casted.* = S.start(context_casted.*);
         }
     };
     var future: Future(Result) = undefined;
-    future.any_future = io.vtable.async(io.userdata, @ptrCast((&future.result)[0..1]), @constCast(&s), TypeErased.start);
+    future.any_future = io.vtable.async(
+        io.userdata,
+        @ptrCast((&future.result)[0..1]),
+        .fromByteUnits(@alignOf(Result)),
+        if (@sizeOf(S) == 0) &.{} else @ptrCast((&s)[0..1]), // work around compiler bug
+        .fromByteUnits(@alignOf(S)),
+        TypeErased.start,
+    );
     return future;
 }