Commit 060fd975d9

Andrew Kelley <andrew@ziglang.org>
2025-10-15 22:48:51
std.Io.Group: add cancellation support to "wait"
1 parent 10bfbd7
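In short: Group.wait now returns Cancelable!void so a blocked waiter can
observe cancellation, and the new Group.waitUncancelable serves paths that
must block to completion. A minimal caller-side sketch of the changed
surface (illustrative only; `gather`, `shutdown`, and `startTask` are
hypothetical, and Group task functions return void):

    fn gather(io: std.Io, group: *std.Io.Group) std.Io.Cancelable!void {
        group.async(io, startTask, .{});
        // New: returns error.Canceled if this task is canceled while blocked.
        try group.wait(io);
    }

    fn shutdown(io: std.Io, group: *std.Io.Group) void {
        // Cleanup paths that must run to completion are uninterruptible:
        group.waitUncancelable(io);
    }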
Changed files (4)
lib/std/Io/net/HostName.zig
@@ -273,7 +273,7 @@ pub fn connectMany(
         .address => |address| group.async(io, enqueueConnection, .{ address, io, results, options }),
         .canonical_name => continue,
         .end => |lookup_result| {
-            group.wait(io);
+            group.waitUncancelable(io);
             results.putOneUncancelable(io, .{ .end = lookup_result });
             return;
         },
lib/std/Io/Threaded.zig
@@ -13,7 +13,6 @@ const IpAddress = std.Io.net.IpAddress;
 const Allocator = std.mem.Allocator;
 const assert = std.debug.assert;
 const posix = std.posix;
-const ResetEvent = std.Thread.ResetEvent;
 
 /// Thread-safe.
 allocator: Allocator,
@@ -153,8 +152,10 @@ pub fn io(t: *Threaded) Io {
             .cancel = cancel,
             .cancelRequested = cancelRequested,
             .select = select,
+
             .groupAsync = groupAsync,
             .groupWait = groupWait,
+            .groupWaitUncancelable = groupWaitUncancelable,
             .groupCancel = groupCancel,
 
             .mutexLock = mutexLock,
@@ -300,7 +301,7 @@ const AsyncClosure = struct {
     }
 
     fn waitAndFree(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
-        ac.reset_event.wait();
+        ac.reset_event.waitUncancelable();
         @memcpy(result, ac.resultPointer()[0..result.len]);
         free(ac, gpa, result.len);
     }
@@ -472,7 +473,7 @@ const GroupClosure = struct {
             assert(cancel_tid == Closure.canceling_tid);
             // We already know the task is canceled before running the callback. Since all closures
             // in a Group have void return type, we can return early.
-            std.Thread.WaitGroup.finishStateless(group_state, reset_event);
+            syncFinish(group_state, reset_event);
             return;
         }
         current_closure = closure;
@@ -485,7 +486,7 @@ const GroupClosure = struct {
             assert(cancel_tid == Closure.canceling_tid);
         }
 
-        std.Thread.WaitGroup.finishStateless(group_state, reset_event);
+        syncFinish(group_state, reset_event);
     }
 
     fn free(gc: *GroupClosure, gpa: Allocator) void {
@@ -505,6 +506,32 @@ const GroupClosure = struct {
         const base: [*]u8 = @ptrCast(gc);
         return base + contextOffset(gc.context_alignment);
     }
+
+    const sync_is_waiting: usize = 1 << 0;
+    const sync_one_pending: usize = 1 << 1;
+
+    fn syncStart(state: *std.atomic.Value(usize)) void {
+        const prev_state = state.fetchAdd(sync_one_pending, .monotonic);
+        assert((prev_state / sync_one_pending) < (std.math.maxInt(usize) / sync_one_pending));
+    }
+
+    fn syncFinish(state: *std.atomic.Value(usize), event: *ResetEvent) void {
+        const prev_state = state.fetchSub(sync_one_pending, .acq_rel);
+        assert((prev_state / sync_one_pending) > 0);
+        if (prev_state == (sync_one_pending | sync_is_waiting)) event.set();
+    }
+
+    fn syncWait(t: *Threaded, state: *std.atomic.Value(usize), event: *ResetEvent) Io.Cancelable!void {
+        const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
+        assert(prev_state & sync_is_waiting == 0);
+        if ((prev_state / sync_one_pending) > 0) try event.wait(t);
+    }
+
+    fn syncWaitUncancelable(state: *std.atomic.Value(usize), event: *ResetEvent) void {
+        const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
+        assert(prev_state & sync_is_waiting == 0);
+        if ((prev_state / sync_one_pending) > 0) event.waitUncancelable();
+    }
 };
 
 fn groupAsync(
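Note: the sync* helpers above pack a wait-group into a single atomic usize:
bit 0 (sync_is_waiting) records that a waiter is blocked on the reset event,
and the remaining bits count pending tasks in units of sync_one_pending. A
worked sequence for two tasks and one waiter (illustrative):

    // state = 0                        no tasks, no waiter
    // syncStart  -> state = 2          one pending
    // syncStart  -> state = 4          two pending
    // syncWait   -> state = 5          waiter flag set; count 2 > 0, block on event
    // syncFinish -> state = 3          prev (5) != one_pending|is_waiting, no wake
    // syncFinish -> state = 1          prev (3) == one_pending|is_waiting, event.set()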
@@ -566,22 +593,40 @@ fn groupAsync(
     // This needs to be done before unlocking the mutex to avoid a race with
     // the associated task finishing.
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
-    std.Thread.WaitGroup.startStateless(group_state);
+    GroupClosure.syncStart(group_state);
 
     t.mutex.unlock();
     t.cond.signal();
 }
 
-fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
+fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Cancelable!void {
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const gpa = t.allocator;
 
     if (builtin.single_threaded) return;
 
-    // TODO these primitives are too high level, need to check cancel on EINTR
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
     const reset_event: *ResetEvent = @ptrCast(&group.context);
-    std.Thread.WaitGroup.waitStateless(group_state, reset_event);
+    try GroupClosure.syncWait(t, group_state, reset_event);
+
+    var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
+    while (true) {
+        const gc: *GroupClosure = @fieldParentPtr("node", node);
+        const node_next = node.next;
+        gc.free(gpa);
+        node = node_next orelse break;
+    }
+}
+
+fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
+    const t: *Threaded = @ptrCast(@alignCast(userdata));
+    const gpa = t.allocator;
+
+    if (builtin.single_threaded) return;
+
+    const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
+    const reset_event: *ResetEvent = @ptrCast(&group.context);
+    GroupClosure.syncWaitUncancelable(group_state, reset_event);
 
     var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
     while (true) {
@@ -609,7 +654,7 @@ fn groupCancel(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void
 
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
     const reset_event: *ResetEvent = @ptrCast(&group.context);
-    std.Thread.WaitGroup.waitStateless(group_state, reset_event);
+    GroupClosure.syncWaitUncancelable(group_state, reset_event);
 
     {
         var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
@@ -661,22 +706,20 @@ fn checkCancel(t: *Threaded) error{Canceled}!void {
 fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) Io.Cancelable!void {
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     if (prev_state == .contended) {
-        try t.checkCancel();
-        futexWait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+        try futexWait(t, @ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
     while (@atomicRmw(Io.Mutex.State, &mutex.state, .Xchg, .contended, .acquire) != .unlocked) {
-        try t.checkCancel();
-        futexWait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+        try futexWait(t, @ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
 }
 
 fn mutexLockUncancelable(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
     _ = userdata;
     if (prev_state == .contended) {
-        futexWait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+        futexWaitUncancelable(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
     while (@atomicRmw(Io.Mutex.State, &mutex.state, .Xchg, .contended, .acquire) != .unlocked) {
-        futexWait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
+        futexWaitUncancelable(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
 }
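Note: in both lock paths the old checkCancel-then-futexWait pair is folded
into the cancelable futexWait(t, ...) defined later in this file, which
checks for cancellation before sleeping; an EINTR wakeup simply returns,
and the caller's loop re-checks state and waits again, re-running the
cancel check. A hedged restatement of the two contracts:

    // try futexWait(t, ptr, expect);
    //   - returns error.Canceled if cancellation was requested
    //   - otherwise returns after a wake, a spurious wakeup, EINTR, or
    //     because ptr.* != expect; callers re-check state in a loop
    //
    // futexWaitUncancelable(ptr, expect);
    //   - same wakeup conditions, but never fails; for paths such as
    //     mutexLockUncancelable that must make progress regardless of
    //     cancellation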
 
@@ -708,7 +751,7 @@ fn conditionWaitUncancelable(userdata: ?*anyopaque, cond: *Io.Condition, mutex:
     defer mutex.lockUncancelable(t_io);
 
     while (true) {
-        futexWait(cond_epoch, epoch);
+        futexWaitUncancelable(cond_epoch, epoch);
         epoch = cond_epoch.load(.acquire);
         state = cond_state.load(.monotonic);
         while (state & signal_mask != 0) {
@@ -747,8 +790,7 @@ fn conditionWait(userdata: ?*anyopaque, cond: *Io.Condition, mutex: *Io.Mutex) I
     defer mutex.lockUncancelable(t.io());
 
     while (true) {
-        try t.checkCancel();
-        futexWait(cond_epoch, epoch);
+        try futexWait(t, cond_epoch, epoch);
 
         epoch = cond_epoch.load(.acquire);
         state = cond_state.load(.monotonic);
@@ -1708,9 +1750,8 @@ fn sleepPosix(userdata: ?*anyopaque, timeout: Io.Timeout) Io.SleepError!void {
     }
 }
 
-fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
+fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) Io.Cancelable!usize {
     const t: *Threaded = @ptrCast(@alignCast(userdata));
-    _ = t;
 
     var reset_event: ResetEvent = .unset;
 
@@ -1720,20 +1761,20 @@ fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
             for (futures[0..i]) |cleanup_future| {
                 const cleanup_closure: *AsyncClosure = @ptrCast(@alignCast(cleanup_future));
                 if (@atomicRmw(?*ResetEvent, &cleanup_closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
-                    cleanup_closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+                    cleanup_closure.reset_event.waitUncancelable(); // Ensure no reference to our stack-allocated reset_event.
                 }
             }
             return i;
         }
     }
 
-    reset_event.wait();
+    try reset_event.wait(t);
 
     var result: ?usize = null;
     for (futures, 0..) |future, i| {
         const closure: *AsyncClosure = @ptrCast(@alignCast(future));
         if (@atomicRmw(?*ResetEvent, &closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
-            closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+            closure.reset_event.waitUncancelable(); // Ensure no reference to our stack-allocated reset_event.
             if (result == null) result = i; // In case multiple are ready, return first.
         }
     }
@@ -3320,11 +3361,12 @@ fn copyCanon(canonical_name_buffer: *[HostName.max_len]u8, name: []const u8) Hos
     return .{ .bytes = dest };
 }
 
-pub fn futexWait(ptr: *const std.atomic.Value(u32), expect: u32) void {
+fn futexWait(t: *Threaded, ptr: *const std.atomic.Value(u32), expect: u32) Io.Cancelable!void {
     @branchHint(.cold);
 
     if (native_os == .linux) {
         const linux = std.os.linux;
+        try t.checkCancel();
         const rc = linux.futex_4arg(ptr, .{ .cmd = .WAIT, .private = true }, expect, null);
         if (builtin.mode == .Debug) switch (linux.E.init(rc)) {
             .SUCCESS => {}, // notified by `wake()`
@@ -3341,7 +3383,28 @@ pub fn futexWait(ptr: *const std.atomic.Value(u32), expect: u32) void {
     @compileError("TODO");
 }
 
-pub fn futexWaitDuration(ptr: *const std.atomic.Value(u32), expect: u32, timeout: Io.Duration) void {
+pub fn futexWaitUncancelable(ptr: *const std.atomic.Value(u32), expect: u32) void {
+    @branchHint(.cold);
+
+    if (native_os == .linux) {
+        const linux = std.os.linux;
+        const rc = linux.futex_4arg(ptr, .{ .cmd = .WAIT, .private = true }, expect, null);
+        if (builtin.mode == .Debug) switch (linux.E.init(rc)) {
+            .SUCCESS => {}, // notified by `wake()`
+            .INTR => {}, // spurious wakeup; caller re-checks state and waits again
+            .AGAIN => {}, // ptr.* != expect
+            .INVAL => {}, // possibly timeout overflow
+            .TIMEDOUT => unreachable,
+            .FAULT => unreachable, // ptr was invalid
+            else => unreachable,
+        };
+        return;
+    }
+
+    @compileError("TODO");
+}
+
+pub fn futexWaitDurationUncancelable(ptr: *const std.atomic.Value(u32), expect: u32, timeout: Io.Duration) void {
     @branchHint(.cold);
 
     if (native_os == .linux) {
@@ -3384,3 +3447,114 @@ pub fn futexWake(ptr: *const std.atomic.Value(u32), max_waiters: u32) void {
 
     @compileError("TODO");
 }
+
+/// A thread-safe logical boolean value which can be `set` and `unset`.
+///
+/// It can also block threads until the value is set, with support for
+/// cancelation. Statically initializable; four bytes on all targets.
+pub const ResetEvent = enum(u32) {
+    unset = 0,
+    waiting = 1,
+    is_set = 2,
+
+    /// Returns whether the logical boolean is `set`.
+    ///
+    /// Once `reset` is called, this returns false until the next `set`.
+    ///
+    /// The memory accesses before the `set` can be said to happen before
+    /// `isSet` returns true.
+    pub fn isSet(re: *const ResetEvent) bool {
+        if (builtin.single_threaded) return switch (re.*) {
+            .unset => false,
+            .waiting => unreachable,
+            .is_set => true,
+        };
+        // Acquire barrier ensures memory accesses before `set` happen before
+        // returning true.
+        return @atomicLoad(ResetEvent, re, .acquire) == .is_set;
+    }
+
+    /// Blocks the calling thread until `set` is called.
+    ///
+    /// This is effectively a more efficient version of `while (!isSet()) {}`.
+    ///
+    /// The memory accesses before the `set` can be said to happen before `wait` returns.
+    pub fn wait(re: *ResetEvent, t: *Threaded) Io.Cancelable!void {
+        if (builtin.single_threaded) switch (re.*) {
+            .unset => unreachable, // Deadlock, no other threads to wake us up.
+            .waiting => unreachable, // Invalid state.
+            .is_set => return,
+        };
+        if (re.isSet()) {
+            @branchHint(.likely);
+            return;
+        }
+        // Try to set the state from `unset` to `waiting` to indicate to the
+        // `set` thread that others are blocked on the ResetEvent. Avoid using
+        // any strict barriers until we know the ResetEvent is set.
+        var state = @atomicLoad(ResetEvent, re, .acquire);
+        if (state == .unset) {
+            state = @cmpxchgStrong(ResetEvent, re, state, .waiting, .acquire, .acquire) orelse .waiting;
+        }
+        while (state == .waiting) {
+            try futexWait(t, @ptrCast(re), @intFromEnum(ResetEvent.waiting));
+            state = @atomicLoad(ResetEvent, re, .acquire);
+        }
+        assert(state == .is_set);
+    }
+
+    /// Same as `wait` except uninterruptible.
+    pub fn waitUncancelable(re: *ResetEvent) void {
+        if (builtin.single_threaded) switch (re.*) {
+            .unset => unreachable, // Deadlock, no other threads to wake us up.
+            .waiting => unreachable, // Invalid state.
+            .is_set => return,
+        };
+        if (re.isSet()) {
+            @branchHint(.likely);
+            return;
+        }
+        // Try to set the state from `unset` to `waiting` to indicate to the
+        // `set` thread that others are blocked on the ResetEvent. Avoid using
+        // any strict barriers until we know the ResetEvent is set.
+        var state = @atomicLoad(ResetEvent, re, .acquire);
+        if (state == .unset) {
+            state = @cmpxchgStrong(ResetEvent, re, state, .waiting, .acquire, .acquire) orelse .waiting;
+        }
+        while (state == .waiting) {
+            futexWaitUncancelable(@ptrCast(re), @intFromEnum(ResetEvent.waiting));
+            state = @atomicLoad(ResetEvent, re, .acquire);
+        }
+        assert(state == .is_set);
+    }
+
+    /// Marks the logical boolean as `set` and unblocks any threads in `wait`
+    /// or `waitUncancelable` to observe the new state.
+    ///
+    /// The logical boolean stays `set` until `reset` is called, making future
+    /// `set` calls do nothing semantically.
+    ///
+    /// The memory accesses before `set` can be said to happen before `isSet`
+    /// returns true or `wait`/`waitUncancelable` return successfully.
+    pub fn set(re: *ResetEvent) void {
+        if (builtin.single_threaded) {
+            re.* = .is_set;
+            return;
+        }
+        if (@atomicRmw(ResetEvent, re, .Xchg, .is_set, .release) == .waiting) {
+            futexWake(@ptrCast(re), std.math.maxInt(u32));
+        }
+    }
+
+    /// Unmarks the ResetEvent as if `set` was never called.
+    ///
+    /// Assumes no threads are blocked in `wait` or `waitUncancelable`. Concurrent
+    /// calls to `set`, `isSet` and `reset` are allowed.
+    pub fn reset(re: *ResetEvent) void {
+        if (builtin.single_threaded) {
+            re.* = .unset;
+            return;
+        }
+        @atomicStore(ResetEvent, re, .unset, .monotonic);
+    }
+};
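A minimal producer/consumer sketch of the futex-based ResetEvent above
(illustrative; `t` is the *Threaded instance and `value` is a hypothetical
payload):

    var value: u32 = 0;
    var done: ResetEvent = .unset; // statically initializable

    // Producer, on another thread:
    //     value = 42;
    //     done.set(); // release: publishes the write to `value`

    // Consumer:
    try done.wait(t); // acquire; may fail with error.Canceled
    assert(value == 42); // ordering guaranteed by set/wait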
lib/std/Io.zig
@@ -641,12 +641,13 @@ pub const VTable = struct {
         context_alignment: std.mem.Alignment,
         start: *const fn (*Group, context: *const anyopaque) void,
     ) void,
-    groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
+    groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) Cancelable!void,
+    groupWaitUncancelable: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
     groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
 
     /// Blocks until one of the futures from the list has a result ready, such
     /// that awaiting it will not block. Returns that index.
-    select: *const fn (?*anyopaque, futures: []const *AnyFuture) usize,
+    select: *const fn (?*anyopaque, futures: []const *AnyFuture) Cancelable!usize,
 
     mutexLock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) Cancelable!void,
     mutexLockUncancelable: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) void,
@@ -1017,10 +1018,19 @@ pub const Group = struct {
     /// Blocks until all tasks of the group finish.
     ///
     /// Idempotent. Not threadsafe.
-    pub fn wait(g: *Group, io: Io) void {
+    pub fn wait(g: *Group, io: Io) Cancelable!void {
         const token = g.token orelse return;
         g.token = null;
-        io.vtable.groupWait(io.userdata, g, token);
+        return io.vtable.groupWait(io.userdata, g, token);
+    }
+
+    /// Equivalent to `wait` except uninterruptible.
+    ///
+    /// Idempotent. Not threadsafe.
+    pub fn waitUncancelable(g: *Group, io: Io) void {
+        const token = g.token orelse return;
+        g.token = null;
+        io.vtable.groupWaitUncancelable(io.userdata, g, token);
     }
 
     /// Equivalent to `wait` but requests cancellation on all tasks owned by
@@ -1095,7 +1105,7 @@ pub fn Select(comptime U: type) type {
         /// Asserts there is at least one more `outstanding` task.
         ///
         /// Not threadsafe.
-        pub fn wait(s: *S) Io.Cancelable!U {
+        pub fn wait(s: *S) Cancelable!U {
             s.outstanding -= 1;
             return s.queue.getOne(s.io);
         }
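With select now fallible, call sites handle cancellation as well; a sketch
going through the vtable directly for clarity (assumes `a` and `b` are
pending *std.Io.AnyFuture handles):

    const futures = [_]*std.Io.AnyFuture{ a, b };
    const ready: usize = io.vtable.select(io.userdata, &futures) catch |err| switch (err) {
        error.Canceled => return err, // propagate cancellation to our awaiter
    };
    // futures[ready] can now be awaited without blocking.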
BRANCH_TODO
@@ -3,7 +3,7 @@
 * Threaded: finish windows impl 
 * Threaded: glibc impl of netLookup
 
-* fix Group.wait not handling cancelation (need to move impl of ResetEvent to Threaded)
+* eliminate dependency on std.Thread (Mutex, Condition, maybe more)
 * implement cancelRequest for non-linux posix
 * finish converting all Threaded into directly calling system functions and handling EINTR
 * audit the TODOs