Commit 99f2ce5e91

Andrew Kelley <andrew@ziglang.org>
2022-11-23 22:24:55
Merge pull request #13577 from ianic/issue-12877
stdlib: fix condition variable broadcast FutexImpl
1 parent 89a491a
Changed files (1)
lib
std
lib/std/Thread/Condition.zig
@@ -194,42 +194,27 @@ const FutexImpl = struct {
     const signal_mask = 0xffff << 16;
 
     fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
-        // Register that we're waiting on the state by incrementing the wait count.
-        // This assumes that there can be at most ((1<<16)-1) or 65,355 threads concurrently waiting on the same Condvar.
-        // If this is hit in practice, then this condvar not working is the least of your concerns.
+        // Observe the epoch, then check the state again to see if we should wake up.
+        // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
+        //
+        // - T1: s = LOAD(&state)
+        // - T2: UPDATE(&s, signal)
+        // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
+        // - T1: e = LOAD(&epoch) (was reordered after the state load)
+        // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
+        //
+        // Acquire barrier to ensure the epoch load happens before the state load.
+        var epoch = self.epoch.load(.Acquire);
         var state = self.state.fetchAdd(one_waiter, .Monotonic);
         assert(state & waiter_mask != waiter_mask);
         state += one_waiter;
 
-        // Temporarily release the mutex in order to block on the condition variable.
         mutex.unlock();
         defer mutex.lock();
 
         var futex_deadline = Futex.Deadline.init(timeout);
-        while (true) {
-            // Try to wake up by consuming a signal and decremented the waiter we added previously.
-            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
-            while (state & signal_mask != 0) {
-                const new_state = state - one_waiter - one_signal;
-                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
-            }
-
-            // Observe the epoch, then check the state again to see if we should wake up.
-            // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
-            //
-            // - T1: s = LOAD(&state)
-            // - T2: UPDATE(&s, signal)
-            // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
-            // - T1: e = LOAD(&epoch) (was reordered after the state load)
-            // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
-            //
-            // Acquire barrier to ensure the epoch load happens before the state load.
-            const epoch = self.epoch.load(.Acquire);
-            state = self.state.load(.Monotonic);
-            if (state & signal_mask != 0) {
-                continue;
-            }
 
+        while (true) {
             futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
                 // On timeout, we must decrement the waiter we added above.
                 error.Timeout => {
@@ -247,6 +232,16 @@ const FutexImpl = struct {
                     }
                 },
             };
+
+            epoch = self.epoch.load(.Acquire);
+            state = self.state.load(.Monotonic);
+
+            // Try to wake up by consuming a signal and decremented the waiter we added previously.
+            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+            while (state & signal_mask != 0) {
+                const new_state = state - one_waiter - one_signal;
+                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
+            }
         }
     }
 
@@ -536,3 +531,150 @@ test "Condition - broadcasting" {
         t.join();
     }
 }
+
+test "Condition - broadcasting - wake all threads" {
+    // Tests issue #12877
+    // This test requires spawning threads
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    var num_runs: usize = 1;
+    const num_threads = 10;
+
+    while (num_runs > 0) : (num_runs -= 1) {
+        const BroadcastTest = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            thread_id_to_wake: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+
+            fn run(self: *@This(), thread_id: usize) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last broadcast thread to start tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
+
+                while (self.thread_id_to_wake != thread_id) {
+                    self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
+                    self.wakeups += 1;
+                }
+                if (self.thread_id_to_wake <= num_threads) {
+                    // Signal next thread to wake up.
+                    self.thread_id_to_wake += 1;
+                    self.cond.broadcast();
+                }
+            }
+        };
+
+        var broadcast_test = BroadcastTest{};
+        var thread_id: usize = 1;
+        for (broadcast_test.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
+            thread_id += 1;
+        }
+
+        {
+            broadcast_test.mutex.lock();
+            defer broadcast_test.mutex.unlock();
+
+            // Wait for all the broadcast threads to spawn.
+            // timedWait() to detect any potential deadlocks.
+            while (broadcast_test.count != num_threads) {
+                try broadcast_test.completed.timedWait(
+                    &broadcast_test.mutex,
+                    1 * std.time.ns_per_s,
+                );
+            }
+
+            // Signal thread 1 to wake up
+            broadcast_test.thread_id_to_wake = 1;
+            broadcast_test.cond.broadcast();
+        }
+
+        for (broadcast_test.threads) |t| {
+            t.join();
+        }
+    }
+}
+
+test "Condition - signal wakes one" {
+    // This test requires spawning threads
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    var num_runs: usize = 1;
+    const num_threads = 3;
+    const timeoutDelay = 10 * std.time.ns_per_ms;
+
+    while (num_runs > 0) : (num_runs -= 1) {
+
+        // Start multiple runner threads, wait for them to start and send the signal
+        // then. Expect that one thread wake up and all other times out.
+        //
+        // Test depends on delay in timedWait! If too small all threads can timeout
+        // before any one gets wake up.
+
+        const Runner = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+            timeouts: usize = 0,
+
+            fn run(self: *@This()) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last started thread tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
+
+                self.cond.timedWait(&self.mutex, timeoutDelay) catch {
+                    self.timeouts += 1;
+                    return;
+                };
+                self.wakeups += 1;
+            }
+        };
+
+        // Start threads
+        var runner = Runner{};
+        for (runner.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner});
+        }
+
+        {
+            runner.mutex.lock();
+            defer runner.mutex.unlock();
+
+            // Wait for all the threads to spawn.
+            // timedWait() to detect any potential deadlocks.
+            while (runner.count != num_threads) {
+                try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s);
+            }
+            // Signal one thread, the others should get timeout.
+            runner.cond.signal();
+        }
+
+        for (runner.threads) |t| {
+            t.join();
+        }
+
+        // Expect that only one got singal
+        try std.testing.expectEqual(runner.wakeups, 1);
+        try std.testing.expectEqual(runner.timeouts, num_threads - 1);
+    }
+}