Commit c88b8e3c15

Andrew Kelley <andrew@ziglang.org>
2025-04-04 06:15:19
std.Io.EventLoop: implement select
1 parent f158ec5
Changed files (2)
lib
std
lib/std/Io/EventLoop.zig
@@ -485,6 +485,7 @@ const SwitchMessage = struct {
         reschedule,
         recycle,
         register_awaiter: *?*Fiber,
+        register_select: []const *Io.AnyFuture,
         mutex_lock: struct {
             prev_state: Io.Mutex.State,
             mutex: *Io.Mutex,
@@ -514,13 +515,21 @@ const SwitchMessage = struct {
             .register_awaiter => |awaiter| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
                 assert(prev_fiber.queue_next == null);
-                if (@atomicRmw(
-                    ?*Fiber,
-                    awaiter,
-                    .Xchg,
-                    prev_fiber,
-                    .acq_rel,
-                ) == Fiber.finished) el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+                if (@atomicRmw(?*Fiber, awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished)
+                    el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+            },
+            .register_select => |futures| {
+                const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
+                assert(prev_fiber.queue_next == null);
+                for (futures) |any_future| {
+                    const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+                    if (@atomicRmw(?*Fiber, &future_fiber.awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished) {
+                        const closure: *AsyncClosure = .fromFiber(future_fiber);
+                        if (!@atomicRmw(bool, &closure.already_awaited, .Xchg, true, .seq_cst)) {
+                            el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+                        }
+                    }
+                }
             },
             .mutex_lock => |mutex_lock| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
@@ -661,6 +670,7 @@ const AsyncClosure = struct {
     fiber: *Fiber,
     start: *const fn (context: *const anyopaque, result: *anyopaque) void,
     result_align: Alignment,
+    already_awaited: bool,
 
     fn contextPointer(closure: *AsyncClosure) [*]align(Fiber.max_context_align.toByteUnits()) u8 {
         return @alignCast(@as([*]u8, @ptrCast(closure)) + @sizeOf(AsyncClosure));
@@ -668,12 +678,24 @@ const AsyncClosure = struct {
 
     fn call(closure: *AsyncClosure, message: *const SwitchMessage) callconv(.withStackAlign(.c, @alignOf(AsyncClosure))) noreturn {
         message.handle(closure.event_loop);
-        std.log.debug("{*} performing async", .{closure.fiber});
-        closure.start(closure.contextPointer(), closure.fiber.resultBytes(closure.result_align));
-        const awaiter = @atomicRmw(?*Fiber, &closure.fiber.awaiter, .Xchg, Fiber.finished, .acq_rel);
-        closure.event_loop.yield(awaiter, .nothing);
+        const fiber = closure.fiber;
+        std.log.debug("{*} performing async", .{fiber});
+        closure.start(closure.contextPointer(), fiber.resultBytes(closure.result_align));
+        const awaiter = @atomicRmw(?*Fiber, &fiber.awaiter, .Xchg, Fiber.finished, .acq_rel);
+        const ready_awaiter = r: {
+            const a = awaiter orelse break :r null;
+            if (@atomicRmw(bool, &closure.already_awaited, .Xchg, true, .acq_rel)) break :r null;
+            break :r a;
+        };
+        closure.event_loop.yield(ready_awaiter, .nothing);
         unreachable; // switched to dead fiber
     }
+
+    fn fromFiber(fiber: *Fiber) *AsyncClosure {
+        return @ptrFromInt(Fiber.max_context_align.max(.of(AsyncClosure)).backward(
+            @intFromPtr(fiber.allocatedEnd()) - Fiber.max_context_size,
+        ) - @sizeOf(AsyncClosure));
+    }
 };
 
 fn @"async"(
@@ -696,9 +718,7 @@ fn @"async"(
     };
     std.log.debug("allocated {*}", .{fiber});
 
-    const closure: *AsyncClosure = @ptrFromInt(Fiber.max_context_align.max(.of(AsyncClosure)).backward(
-        @intFromPtr(fiber.allocatedEnd()) - Fiber.max_context_size,
-    ) - @sizeOf(AsyncClosure));
+    const closure: *AsyncClosure = .fromFiber(fiber);
     const stack_end: [*]usize = @alignCast(@ptrCast(closure));
     (stack_end - 1)[0..1].* = .{@intFromPtr(&AsyncClosure.call)};
     fiber.* = .{
@@ -721,6 +741,7 @@ fn @"async"(
         .fiber = fiber,
         .start = start,
         .result_align = result_alignment,
+        .already_awaited = false,
     };
     @memcpy(closure.contextPointer(), context);
 
@@ -728,13 +749,6 @@ fn @"async"(
     return @ptrCast(fiber);
 }
 
-fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
-    const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    _ = el;
-    _ = futures;
-    @panic("TODO");
-}
-
 const DetachedClosure = struct {
     event_loop: *EventLoop,
     fiber: *Fiber,
@@ -836,6 +850,42 @@ fn @"await"(
     event_loop.recycle(future_fiber);
 }
 
+fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
+    const el: *EventLoop = @alignCast(@ptrCast(userdata));
+
+    // Optimization to avoid the yield below.
+    for (futures, 0..) |any_future, i| {
+        const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+        if (@atomicLoad(?*Fiber, &future_fiber.awaiter, .acquire) == Fiber.finished)
+            return i;
+    }
+
+    el.yield(null, .{ .register_select = futures });
+
+    std.log.debug("back from select yield", .{});
+
+    const my_thread: *Thread = .current();
+    const my_fiber = my_thread.currentFiber();
+    var result: ?usize = null;
+
+    for (futures, 0..) |any_future, i| {
+        const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+        if (@cmpxchgStrong(?*Fiber, &future_fiber.awaiter, my_fiber, null, .seq_cst, .seq_cst)) |awaiter| {
+            if (awaiter == Fiber.finished) {
+                if (result == null) result = i;
+            } else if (awaiter) |a| {
+                const closure: *AsyncClosure = .fromFiber(a);
+                closure.already_awaited = false;
+            }
+        } else {
+            const closure: *AsyncClosure = .fromFiber(my_fiber);
+            closure.already_awaited = false;
+        }
+    }
+
+    return result.?;
+}
+
 fn cancel(
     userdata: ?*anyopaque,
     any_future: *std.Io.AnyFuture,
lib/std/Thread/Pool.zig
@@ -364,7 +364,7 @@ const AsyncClosure = struct {
     context_offset: usize,
     result_offset: usize,
 
-    const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(std.mem.alignBackward(usize, std.math.maxInt(usize), @alignOf(std.Thread.ResetEvent)));
+    const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(@alignOf(std.Thread.ResetEvent));
 
     const canceling_tid: std.Thread.Id = switch (@typeInfo(std.Thread.Id)) {
         .int => |int_info| switch (int_info.signedness) {