Commit c4fcf85c43

Jacob Young <jacobly0@users.noreply.github.com>
2025-04-03 00:03:53
Io.Condition: implement full API
1 parent 3eb7be5
Changed files (3)
lib
lib/std/Io/EventLoop.zig
@@ -555,8 +555,24 @@ const SwitchMessage = struct {
             .condition_wait => |condition_wait| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
                 assert(prev_fiber.queue_next == null);
-                const cond_state: *?*Fiber = @ptrCast(&condition_wait.cond.state);
-                assert(@atomicRmw(?*Fiber, cond_state, .Xchg, prev_fiber, .release) == null); // More than one wait on same Condition is illegal.
+                const cond_impl = prev_fiber.resultPointer(ConditionImpl);
+                cond_impl.* = .{
+                    .tail = prev_fiber,
+                    .event = .queued,
+                };
+                if (@cmpxchgStrong(
+                    ?*Fiber,
+                    @as(*?*Fiber, @ptrCast(&condition_wait.cond.state)),
+                    null,
+                    prev_fiber,
+                    .release,
+                    .acquire,
+                )) |waiting_fiber| {
+                    const waiting_cond_impl = waiting_fiber.?.resultPointer(ConditionImpl);
+                    assert(waiting_cond_impl.tail.queue_next == null);
+                    waiting_cond_impl.tail.queue_next = prev_fiber;
+                    waiting_cond_impl.tail = prev_fiber;
+                }
                 condition_wait.mutex.unlock(el.io());
             },
             .exit => for (el.threads.allocated[0..@atomicLoad(u32, &el.threads.active, .acquire)]) |*each_thread| {
@@ -1267,10 +1283,7 @@ fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadl
 
 fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    el.yield(null, .{ .mutex_lock = .{
-        .prev_state = prev_state,
-        .mutex = mutex,
-    } });
+    el.yield(null, .{ .mutex_lock = .{ .prev_state = prev_state, .mutex = mutex } });
 }
 fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
     var maybe_waiting_fiber: ?*Fiber = @ptrFromInt(@intFromEnum(prev_state));
@@ -1294,21 +1307,48 @@ fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mut
     el.yield(maybe_waiting_fiber.?, .reschedule);
 }
 
+const ConditionImpl = struct {
+    tail: *Fiber,
+    event: union(enum) {
+        queued,
+        wake: Io.Condition.Wake,
+    },
+};
+
 fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) Io.Cancelable!void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    el.yield(null, .{ .condition_wait = .{
-        .cond = cond,
-        .mutex = mutex,
-    } });
+    el.yield(null, .{ .condition_wait = .{ .cond = cond, .mutex = mutex } });
+    const thread = Thread.current();
+    const fiber = thread.currentFiber();
+    const cond_impl = fiber.resultPointer(ConditionImpl);
     try mutex.lock(el.io());
+    switch (cond_impl.event) {
+        .queued => {},
+        .wake => |wake| if (fiber.queue_next) |next_fiber| switch (wake) {
+            .one => if (@cmpxchgStrong(
+                ?*Fiber,
+                @as(*?*Fiber, @ptrCast(&cond.state)),
+                null,
+                next_fiber,
+                .release,
+                .acquire,
+            )) |old_fiber| {
+                const old_cond_impl = old_fiber.?.resultPointer(ConditionImpl);
+                assert(old_cond_impl.tail.queue_next == null);
+                old_cond_impl.tail.queue_next = next_fiber;
+                old_cond_impl.tail = cond_impl.tail;
+            },
+            .all => el.schedule(thread, .{ .head = next_fiber, .tail = cond_impl.tail }),
+        },
+    }
+    fiber.queue_next = null;
 }
 
-fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
-    const cond_state: *?*Fiber = @ptrCast(&cond.state);
-    if (@atomicRmw(?*Fiber, cond_state, .Xchg, null, .acquire)) |fiber| {
-        el.yield(fiber, .reschedule);
-    }
+    const waiting_fiber = @atomicRmw(?*Fiber, @as(*?*Fiber, @ptrCast(&cond.state)), .Xchg, null, .acquire) orelse return;
+    waiting_fiber.resultPointer(ConditionImpl).event = .{ .wake = wake };
+    el.yield(waiting_fiber, .reschedule);
 }
 
 fn errno(signed: i32) std.os.linux.E {
lib/std/Thread/Pool.zig
@@ -666,7 +666,7 @@ fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) I
     }
 }
 
-fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.Wake) void {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     _ = pool;
     comptime assert(@TypeOf(cond.state) == u64);
@@ -690,7 +690,10 @@ fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition) void {
             return;
         }
 
-        const to_wake = 1;
+        const to_wake = switch (wake) {
+            .one => 1,
+            .all => wakeable,
+        };
 
         // Reserve the amount of waiters to wake by incrementing the signals count.
         // Release barrier ensures code before the wake() happens before the signal it posted and consumed by the wait() threads.
lib/std/Io.zig
@@ -630,7 +630,7 @@ pub const VTable = struct {
     mutexUnlock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) void,
 
     conditionWait: *const fn (?*anyopaque, cond: *Condition, mutex: *Mutex) Cancelable!void,
-    conditionWake: *const fn (?*anyopaque, cond: *Condition) void,
+    conditionWake: *const fn (?*anyopaque, cond: *Condition, wake: Condition.Wake) void,
 
     createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File,
     openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File,
@@ -809,9 +809,20 @@ pub const Condition = struct {
         return io.vtable.conditionWait(io.userdata, cond, mutex);
     }
 
-    pub fn wake(cond: *Condition, io: Io) void {
-        io.vtable.conditionWake(io.userdata, cond);
+    pub fn signal(cond: *Condition, io: Io) void {
+        io.vtable.conditionWake(io.userdata, cond, .one);
     }
+
+    pub fn broadcast(cond: *Condition, io: Io) void {
+        io.vtable.conditionWake(io.userdata, cond, .all);
+    }
+
+    pub const Wake = enum {
+        /// wake up only one thread
+        one,
+        /// wake up all thread
+        all,
+    };
 };
 
 pub const TypeErasedQueue = struct {
@@ -863,7 +874,7 @@ pub const TypeErasedQueue = struct {
             remaining = remaining[copy_len..];
             getter.data.remaining = getter.data.remaining[copy_len..];
             if (getter.data.remaining.len == 0) {
-                getter.data.condition.wake(io);
+                getter.data.condition.signal(io);
                 continue;
             }
             q.getters.prepend(getter);
@@ -946,7 +957,7 @@ pub const TypeErasedQueue = struct {
                 putter.data.remaining = putter.data.remaining[copy_len..];
                 remaining = remaining[copy_len..];
                 if (putter.data.remaining.len == 0) {
-                    putter.data.condition.wake(io);
+                    putter.data.condition.signal(io);
                 } else {
                     assert(remaining.len == 0);
                     q.putters.prepend(putter);
@@ -979,7 +990,7 @@ pub const TypeErasedQueue = struct {
             putter.data.remaining = putter.data.remaining[copy_len..];
             q.put_index += copy_len;
             if (putter.data.remaining.len == 0) {
-                putter.data.condition.wake(io);
+                putter.data.condition.signal(io);
                 continue;
             }
             const second_available = q.buffer[0..q.get_index];
@@ -988,7 +999,7 @@ pub const TypeErasedQueue = struct {
             putter.data.remaining = putter.data.remaining[copy_len..];
             q.put_index = copy_len;
             if (putter.data.remaining.len == 0) {
-                putter.data.condition.wake(io);
+                putter.data.condition.signal(io);
                 continue;
             }
             q.putters.prepend(putter);