Commit 03fd132b1c

Andrew Kelley <andrew@ziglang.org>
2025-10-28 23:06:07
std.Io: fix Group.wait unsoundness
Previously, if a call to Group.wait was canceled, a subsequent call to wait() or cancel() would trip an assertion in the synchronization code.
1 parent 6c794ce
Changed files (4)
lib/std/Io/net/HostName.zig
@@ -280,9 +280,8 @@ pub fn connectMany(
         .address => |address| group.async(io, enqueueConnection, .{ address, io, results, options }),
         .canonical_name => continue,
         .end => |lookup_result| {
-            results.putOneUncancelable(io, .{
-                .end = if (group.wait(io)) lookup_result else |err| err,
-            });
+            group.wait(io);
+            results.putOneUncancelable(io, .{ .end = lookup_result });
             return;
         },
     } else |err| switch (err) {
lib/std/Io/Kqueue.zig
@@ -859,7 +859,6 @@ pub fn io(k: *Kqueue) Io {
 
             .groupAsync = groupAsync,
             .groupWait = groupWait,
-            .groupWaitUncancelable = groupWaitUncancelable,
             .groupCancel = groupCancel,
 
             .mutexLock = mutexLock,
@@ -1027,15 +1026,7 @@ fn groupAsync(
     @panic("TODO");
 }
 
-fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Cancelable!void {
-    const k: *Kqueue = @ptrCast(@alignCast(userdata));
-    _ = k;
-    _ = group;
-    _ = token;
-    @panic("TODO");
-}
-
-fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
+fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
     const k: *Kqueue = @ptrCast(@alignCast(userdata));
     _ = k;
     _ = group;
lib/std/Io/Threaded.zig
@@ -177,7 +177,6 @@ pub fn io(t: *Threaded) Io {
 
             .groupAsync = groupAsync,
             .groupWait = groupWait,
-            .groupWaitUncancelable = groupWaitUncancelable,
             .groupCancel = groupCancel,
 
             .mutexLock = mutexLock,
@@ -274,7 +273,6 @@ pub fn ioBasic(t: *Threaded) Io {
 
             .groupAsync = groupAsync,
             .groupWait = groupWait,
-            .groupWaitUncancelable = groupWaitUncancelable,
             .groupCancel = groupCancel,
 
             .mutexLock = mutexLock,
@@ -579,7 +577,9 @@ const GroupClosure = struct {
             assert(cancel_tid == .canceling);
         }
 
-        syncFinish(group_state, reset_event);
+        const prev_state = group_state.fetchSub(sync_one_pending, .acq_rel);
+        assert((prev_state / sync_one_pending) > 0);
+        if (prev_state == (sync_one_pending | sync_is_waiting)) reset_event.set();
     }
 
     fn free(gc: *GroupClosure, gpa: Allocator) void {
@@ -602,29 +602,6 @@ const GroupClosure = struct {
 
     const sync_is_waiting: usize = 1 << 0;
     const sync_one_pending: usize = 1 << 1;
-
-    fn syncStart(state: *std.atomic.Value(usize)) void {
-        const prev_state = state.fetchAdd(sync_one_pending, .monotonic);
-        assert((prev_state / sync_one_pending) < (std.math.maxInt(usize) / sync_one_pending));
-    }
-
-    fn syncFinish(state: *std.atomic.Value(usize), event: *ResetEvent) void {
-        const prev_state = state.fetchSub(sync_one_pending, .acq_rel);
-        assert((prev_state / sync_one_pending) > 0);
-        if (prev_state == (sync_one_pending | sync_is_waiting)) event.set();
-    }
-
-    fn syncWait(t: *Threaded, state: *std.atomic.Value(usize), event: *ResetEvent) Io.Cancelable!void {
-        const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
-        assert(prev_state & sync_is_waiting == 0);
-        if ((prev_state / sync_one_pending) > 0) try event.wait(t);
-    }
-
-    fn syncWaitUncancelable(state: *std.atomic.Value(usize), event: *ResetEvent) void {
-        const prev_state = state.fetchAdd(sync_is_waiting, .acquire);
-        assert(prev_state & sync_is_waiting == 0);
-        if ((prev_state / sync_one_pending) > 0) event.waitUncancelable();
-    }
 };
 
 fn groupAsync(
@@ -686,32 +663,14 @@ fn groupAsync(
     // This needs to be done before unlocking the mutex to avoid a race with
     // the associated task finishing.
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
-    GroupClosure.syncStart(group_state);
+    const prev_state = group_state.fetchAdd(GroupClosure.sync_one_pending, .monotonic);
+    assert((prev_state / GroupClosure.sync_one_pending) < (std.math.maxInt(usize) / GroupClosure.sync_one_pending));
 
     t.mutex.unlock();
     t.cond.signal();
 }
 
-fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) Io.Cancelable!void {
-    const t: *Threaded = @ptrCast(@alignCast(userdata));
-    const gpa = t.allocator;
-
-    if (builtin.single_threaded) return;
-
-    const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
-    const reset_event: *ResetEvent = @ptrCast(&group.context);
-    try GroupClosure.syncWait(t, group_state, reset_event);
-
-    var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
-    while (true) {
-        const gc: *GroupClosure = @fieldParentPtr("node", node);
-        const node_next = node.next;
-        gc.free(gpa);
-        node = node_next orelse break;
-    }
-}
-
-fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
+fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const gpa = t.allocator;
 
@@ -719,7 +678,19 @@ fn groupWaitUncancelable(userdata: ?*anyopaque, group: *Io.Group, token: *anyopa
 
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
     const reset_event: *ResetEvent = @ptrCast(&group.context);
-    GroupClosure.syncWaitUncancelable(group_state, reset_event);
+    const prev_state = group_state.fetchAdd(GroupClosure.sync_is_waiting, .acquire);
+    assert(prev_state & GroupClosure.sync_is_waiting == 0);
+    if ((prev_state / GroupClosure.sync_one_pending) > 0) reset_event.wait(t) catch |err| switch (err) {
+        error.Canceled => {
+            var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
+            while (true) {
+                const gc: *GroupClosure = @fieldParentPtr("node", node);
+                gc.closure.requestCancel();
+                node = node.next orelse break;
+            }
+            reset_event.waitUncancelable();
+        },
+    };
 
     var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
     while (true) {
@@ -747,7 +718,9 @@ fn groupCancel(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void
 
     const group_state: *std.atomic.Value(usize) = @ptrCast(&group.state);
     const reset_event: *ResetEvent = @ptrCast(&group.context);
-    GroupClosure.syncWaitUncancelable(group_state, reset_event);
+    const prev_state = group_state.fetchAdd(GroupClosure.sync_is_waiting, .acquire);
+    assert(prev_state & GroupClosure.sync_is_waiting == 0);
+    if ((prev_state / GroupClosure.sync_one_pending) > 0) reset_event.waitUncancelable();
 
     {
         var node: *std.SinglyLinkedList.Node = @ptrCast(@alignCast(token));
@@ -1549,7 +1522,7 @@ fn dirAccessPosix(
             .FAULT => |err| return errnoBug(err),
             .IO => return error.InputOutput,
             .NOMEM => return error.SystemResources,
-            .ILSEQ => return error.BadPathName, // TODO move to wasi
+            .ILSEQ => return error.BadPathName,
             else => |err| return posix.unexpectedErrno(err),
         }
     }
lib/std/Io.zig
@@ -653,8 +653,7 @@ pub const VTable = struct {
         context_alignment: std.mem.Alignment,
         start: *const fn (*Group, context: *const anyopaque) void,
     ) void,
-    groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) Cancelable!void,
-    groupWaitUncancelable: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
+    groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
     groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
 
     /// Blocks until one of the futures from the list has a result ready, such
@@ -1038,29 +1037,18 @@ pub const Group = struct {
         io.vtable.groupAsync(io.userdata, g, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
     }
 
-    /// Blocks until all tasks of the group finish.
-    ///
-    /// On success, further calls to `wait`, `waitUncancelable`, and `cancel`
-    /// do nothing.
-    ///
-    /// Not threadsafe.
-    pub fn wait(g: *Group, io: Io) Cancelable!void {
-        const token = g.token orelse return;
-        try io.vtable.groupWait(io.userdata, g, token);
-        g.token = null;
-    }
-
-    /// Equivalent to `wait` except uninterruptible.
+    /// Blocks until all tasks of the group finish. During this time,
+    /// cancellation requests propagate to all members of the group.
     ///
     /// Idempotent. Not threadsafe.
-    pub fn waitUncancelable(g: *Group, io: Io) void {
+    pub fn wait(g: *Group, io: Io) void {
         const token = g.token orelse return;
         g.token = null;
-        io.vtable.groupWaitUncancelable(io.userdata, g, token);
+        io.vtable.groupWait(io.userdata, g, token);
     }
 
-    /// Equivalent to `wait` but requests cancellation on all tasks owned by
-    /// the group.
+    /// Equivalent to `wait` but immediately requests cancellation on all
+    /// members of the group.
     ///
     /// Idempotent. Not threadsafe.
     pub fn cancel(g: *Group, io: Io) void {