Commit aca7feb8fa

Andrew Kelley <andrew@ziglang.org>
2024-05-28 00:22:18
std.Progress: fix race condition with setIpcFd
The update thread was sometimes reading the special state and then incorrectly getting 0 for the file descriptor, making it hang since it tried to read from stdin.
1 parent bb1f4d2
Changed files (1)
lib
lib/std/Progress.zig
@@ -85,6 +85,7 @@ pub const Node = struct {
         estimated_total_count: u32,
         name: [max_name_len]u8,
 
+        /// Not thread-safe.
         fn getIpcFd(s: Storage) ?posix.fd_t {
             return if (s.estimated_total_count == std.math.maxInt(u32)) switch (@typeInfo(posix.fd_t)) {
                 .Int => @bitCast(s.completed_count),
@@ -93,15 +94,21 @@ pub const Node = struct {
             } else null;
         }
 
+        /// Thread-safe.
         fn setIpcFd(s: *Storage, fd: posix.fd_t) void {
-            s.estimated_total_count = std.math.maxInt(u32);
-            s.completed_count = switch (@typeInfo(posix.fd_t)) {
+            const integer: u32 = switch (@typeInfo(posix.fd_t)) {
                 .Int => @bitCast(fd),
                 .Pointer => @intFromPtr(fd),
                 else => @compileError("unsupported fd_t of " ++ @typeName(posix.fd_t)),
             };
+            // `estimated_total_count` max int indicates the special state that
+            // causes `completed_count` to be treated as a file descriptor, so
+            // the order here matters.
+            @atomicStore(u32, &s.completed_count, integer, .seq_cst);
+            @atomicStore(u32, &s.estimated_total_count, std.math.maxInt(u32), .seq_cst);
         }
 
+        /// Not thread-safe.
         fn byteSwap(s: *Storage) void {
             s.completed_count = @byteSwap(s.completed_count);
             s.estimated_total_count = @byteSwap(s.estimated_total_count);
@@ -208,7 +215,9 @@ pub const Node = struct {
     pub fn setEstimatedTotalItems(n: Node, count: usize) void {
         const index = n.index.unwrap() orelse return;
         const storage = storageByIndex(index);
-        @atomicStore(u32, &storage.estimated_total_count, std.math.lossyCast(u32, count), .monotonic);
+        // Avoid u32 max int which is used to indicate a special state.
+        const saturated = @min(std.math.maxInt(u32) - 1, count);
+        @atomicStore(u32, &storage.estimated_total_count, saturated, .monotonic);
     }
 
     /// Thread-safe.
@@ -243,10 +252,13 @@ pub const Node = struct {
         }
     }
 
-    /// Posix-only. Used by `std.process.Child`.
+    /// Posix-only. Used by `std.process.Child`. Thread-safe.
     pub fn setIpcFd(node: Node, fd: posix.fd_t) void {
         const index = node.index.unwrap() orelse return;
-        assert(fd != -1);
+        assert(fd >= 0);
+        assert(fd != posix.STDOUT_FILENO);
+        assert(fd != posix.STDIN_FILENO);
+        assert(fd != posix.STDERR_FILENO);
         storageByIndex(index).setIpcFd(fd);
     }
 
@@ -582,8 +594,8 @@ fn serialize(serialized_buffer: *Serialized.Buffer) Serialized {
         while (begin_parent != .unused) {
             const dest_storage = &serialized_buffer.storage[serialized_len];
             @memcpy(&dest_storage.name, &storage_ptr.name);
-            dest_storage.completed_count = @atomicLoad(u32, &storage_ptr.completed_count, .monotonic);
-            dest_storage.estimated_total_count = @atomicLoad(u32, &storage_ptr.estimated_total_count, .monotonic);
+            dest_storage.completed_count = @atomicLoad(u32, &storage_ptr.completed_count, .seq_cst);
+            dest_storage.estimated_total_count = @atomicLoad(u32, &storage_ptr.estimated_total_count, .seq_cst);
             const end_parent = @atomicLoad(Node.Parent, parent_ptr, .seq_cst);
             if (begin_parent == end_parent) {
                 any_ipc = any_ipc or (dest_storage.getIpcFd() != null);