Commit f229b74099

Igor Anić <igor.anic@gmail.com>
2022-11-17 20:58:45
stdlib: fix condition variable broadcast FutexImpl
fixes #12877 Current implementation (before this fix) observes number of waiters when broadcast occurs and then makes that number of wakeups. If we have multiple threads waiting for wakeup which immediately go into wait if wakeup is not for that thread (as described in the issue). The same thread can get multiple wakeups while some got none. That is not consistent with documented behavior for condition variable broadcast: `Unblocks all threads currently blocked in a call to wait() or timedWait() with a given Mutex.`. This fix ensures that the thread waiting on futext is woken up on futex wake.
1 parent 88a0f3d
Changed files (1)
lib
std
lib/std/Thread/Condition.zig
@@ -194,59 +194,50 @@ const FutexImpl = struct {
     const signal_mask = 0xffff << 16;
 
     fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
-        // Register that we're waiting on the state by incrementing the wait count.
-        // This assumes that there can be at most ((1<<16)-1) or 65,355 threads concurrently waiting on the same Condvar.
-        // If this is hit in practice, then this condvar not working is the least of your concerns.
+        // Observe the epoch, then check the state again to see if we should wake up.
+        // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
+        //
+        // - T1: s = LOAD(&state)
+        // - T2: UPDATE(&s, signal)
+        // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
+        // - T1: e = LOAD(&epoch) (was reordered after the state load)
+        // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
+        //
+        // Acquire barrier to ensure the epoch load happens before the state load.
+        const epoch = self.epoch.load(.Acquire);
         var state = self.state.fetchAdd(one_waiter, .Monotonic);
         assert(state & waiter_mask != waiter_mask);
         state += one_waiter;
+        var futex_deadline = Futex.Deadline.init(timeout);
 
-        // Temporarily release the mutex in order to block on the condition variable.
         mutex.unlock();
         defer mutex.lock();
 
-        var futex_deadline = Futex.Deadline.init(timeout);
-        while (true) {
-            // Try to wake up by consuming a signal and decremented the waiter we added previously.
-            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
-            while (state & signal_mask != 0) {
-                const new_state = state - one_waiter - one_signal;
-                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
-            }
+        futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
+            // On timeout, we must decrement the waiter we added above.
+            error.Timeout => {
+                while (true) {
+                    // If there's a signal when we're timing out, consume it and report being woken up instead.
+                    // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+                    while (state & signal_mask != 0) {
+                        const new_state = state - one_waiter - one_signal;
+                        state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
+                    }
 
-            // Observe the epoch, then check the state again to see if we should wake up.
-            // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
-            //
-            // - T1: s = LOAD(&state)
-            // - T2: UPDATE(&s, signal)
-            // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
-            // - T1: e = LOAD(&epoch) (was reordered after the state load)
-            // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
-            //
-            // Acquire barrier to ensure the epoch load happens before the state load.
-            const epoch = self.epoch.load(.Acquire);
-            state = self.state.load(.Monotonic);
+                    // Remove the waiter we added and officially return timed out.
+                    const new_state = state - one_waiter;
+                    state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
+                }
+            },
+        };
+
+        while (true) {
+            // Wait thread, decrement waiter and consume signal if exists.
+            var new_state = state - one_waiter;
             if (state & signal_mask != 0) {
-                continue;
+                new_state = state - one_signal;
             }
-
-            futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
-                // On timeout, we must decrement the waiter we added above.
-                error.Timeout => {
-                    while (true) {
-                        // If there's a signal when we're timing out, consume it and report being woken up instead.
-                        // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
-                        while (state & signal_mask != 0) {
-                            const new_state = state - one_waiter - one_signal;
-                            state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
-                        }
-
-                        // Remove the waiter we added and officially return timed out.
-                        const new_state = state - one_waiter;
-                        state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
-                    }
-                },
-            };
+            state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
         }
     }
 
@@ -536,3 +527,74 @@ test "Condition - broadcasting" {
         t.join();
     }
 }
+
+test "Condition - broadcasting - wake all threads" {
+    // Tests issue #12877
+    // This test requires spawning threads
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    const num_threads = 10;
+
+    const BroadcastTest = struct {
+        mutex: Mutex = .{},
+        cond: Condition = .{},
+        completed: Condition = .{},
+        count: usize = 0,
+        thread_id_to_wake: usize = 0,
+        threads: [num_threads]std.Thread = undefined,
+        wakeups: usize = 0,
+
+        fn run(self: *@This(), thread_id: usize) void {
+            self.mutex.lock();
+            defer self.mutex.unlock();
+
+            // The last broadcast thread to start tells the main test thread it's completed.
+            self.count += 1;
+            if (self.count == num_threads) {
+                self.completed.signal();
+            }
+
+            while (self.thread_id_to_wake != thread_id) {
+                self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
+                self.wakeups += 1;
+            }
+            if (self.thread_id_to_wake <= num_threads) {
+                // Signal next thread to wake up.
+                self.thread_id_to_wake += 1;
+                self.cond.broadcast();
+            }
+        }
+    };
+
+    var broadcast_test = BroadcastTest{};
+    var thread_id: usize = 1;
+    for (broadcast_test.threads) |*t| {
+        t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
+        thread_id += 1;
+    }
+
+    {
+        broadcast_test.mutex.lock();
+        defer broadcast_test.mutex.unlock();
+
+        // Wait for all the broadcast threads to spawn.
+        // timedWait() to detect any potential deadlocks.
+        while (broadcast_test.count != num_threads) {
+            try broadcast_test.completed.timedWait(
+                &broadcast_test.mutex,
+                1 * std.time.ns_per_s,
+            );
+        }
+
+        // Signal thread 1 to wake up
+        broadcast_test.thread_id_to_wake = 1;
+        broadcast_test.cond.broadcast();
+    }
+
+    for (broadcast_test.threads) |t| {
+        t.join();
+    }
+    //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups});
+}