Commit 9947b47d80

Igor Anić <igor.anic@gmail.com>
2022-11-21 17:26:54
stdlib: Thread.Condition wake only if signaled
Previous implementation didn't check whether there are pending signals after return from futex.wait. While it is ok for broadcast case it can result in multiple wakeups when only one thread is signaled. This implementation checks that there are pending signals before returning from wait. It is similar to the original implementation but the without initial signal check, here we first go to the futex and then check for pending signal.
1 parent f229b74
Changed files (1)
lib
std
lib/std/Thread/Condition.zig
@@ -204,40 +204,44 @@ const FutexImpl = struct {
         // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
         //
         // Acquire barrier to ensure the epoch load happens before the state load.
-        const epoch = self.epoch.load(.Acquire);
+        var epoch = self.epoch.load(.Acquire);
         var state = self.state.fetchAdd(one_waiter, .Monotonic);
         assert(state & waiter_mask != waiter_mask);
         state += one_waiter;
-        var futex_deadline = Futex.Deadline.init(timeout);
 
         mutex.unlock();
         defer mutex.lock();
 
-        futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
-            // On timeout, we must decrement the waiter we added above.
-            error.Timeout => {
-                while (true) {
-                    // If there's a signal when we're timing out, consume it and report being woken up instead.
-                    // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
-                    while (state & signal_mask != 0) {
-                        const new_state = state - one_waiter - one_signal;
-                        state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
+        var futex_deadline = Futex.Deadline.init(timeout);
+
+        while (true) {
+            futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
+                // On timeout, we must decrement the waiter we added above.
+                error.Timeout => {
+                    while (true) {
+                        // If there's a signal when we're timing out, consume it and report being woken up instead.
+                        // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+                        while (state & signal_mask != 0) {
+                            const new_state = state - one_waiter - one_signal;
+                            state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
+                        }
+
+                        // Remove the waiter we added and officially return timed out.
+                        const new_state = state - one_waiter;
+                        state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
                     }
+                },
+            };
 
-                    // Remove the waiter we added and officially return timed out.
-                    const new_state = state - one_waiter;
-                    state = self.state.tryCompareAndSwap(state, new_state, .Monotonic, .Monotonic) orelse return err;
-                }
-            },
-        };
+            epoch = self.epoch.load(.Acquire);
+            state = self.state.load(.Monotonic);
 
-        while (true) {
-            // Wait thread, decrement waiter and consume signal if exists.
-            var new_state = state - one_waiter;
-            if (state & signal_mask != 0) {
-                new_state = state - one_signal;
+            // Try to wake up by consuming a signal and decremented the waiter we added previously.
+            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
+            while (state & signal_mask != 0) {
+                const new_state = state - one_waiter - one_signal;
+                state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
             }
-            state = self.state.tryCompareAndSwap(state, new_state, .Acquire, .Monotonic) orelse return;
         }
     }
 
@@ -535,66 +539,142 @@ test "Condition - broadcasting - wake all threads" {
         return error.SkipZigTest;
     }
 
+    var num_runs: usize = 1;
     const num_threads = 10;
 
-    const BroadcastTest = struct {
-        mutex: Mutex = .{},
-        cond: Condition = .{},
-        completed: Condition = .{},
-        count: usize = 0,
-        thread_id_to_wake: usize = 0,
-        threads: [num_threads]std.Thread = undefined,
-        wakeups: usize = 0,
-
-        fn run(self: *@This(), thread_id: usize) void {
-            self.mutex.lock();
-            defer self.mutex.unlock();
+    while (num_runs > 0) : (num_runs -= 1) {
+        const BroadcastTest = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            thread_id_to_wake: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+
+            fn run(self: *@This(), thread_id: usize) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last broadcast thread to start tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
 
-            // The last broadcast thread to start tells the main test thread it's completed.
-            self.count += 1;
-            if (self.count == num_threads) {
-                self.completed.signal();
+                while (self.thread_id_to_wake != thread_id) {
+                    self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
+                    self.wakeups += 1;
+                }
+                if (self.thread_id_to_wake <= num_threads) {
+                    // Signal next thread to wake up.
+                    self.thread_id_to_wake += 1;
+                    self.cond.broadcast();
+                }
             }
+        };
 
-            while (self.thread_id_to_wake != thread_id) {
-                self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch std.debug.panic("thread_id {d} timeout {d}", .{ thread_id, self.thread_id_to_wake });
-                self.wakeups += 1;
-            }
-            if (self.thread_id_to_wake <= num_threads) {
-                // Signal next thread to wake up.
-                self.thread_id_to_wake += 1;
-                self.cond.broadcast();
+        var broadcast_test = BroadcastTest{};
+        var thread_id: usize = 1;
+        for (broadcast_test.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
+            thread_id += 1;
+        }
+
+        {
+            broadcast_test.mutex.lock();
+            defer broadcast_test.mutex.unlock();
+
+            // Wait for all the broadcast threads to spawn.
+            // timedWait() to detect any potential deadlocks.
+            while (broadcast_test.count != num_threads) {
+                try broadcast_test.completed.timedWait(
+                    &broadcast_test.mutex,
+                    1 * std.time.ns_per_s,
+                );
             }
+
+            // Signal thread 1 to wake up
+            broadcast_test.thread_id_to_wake = 1;
+            broadcast_test.cond.broadcast();
         }
-    };
 
-    var broadcast_test = BroadcastTest{};
-    var thread_id: usize = 1;
-    for (broadcast_test.threads) |*t| {
-        t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
-        thread_id += 1;
+        for (broadcast_test.threads) |t| {
+            t.join();
+        }
     }
+}
 
-    {
-        broadcast_test.mutex.lock();
-        defer broadcast_test.mutex.unlock();
+test "Condition - signal wakes one" {
+    // This test requires spawning threads
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
 
-        // Wait for all the broadcast threads to spawn.
-        // timedWait() to detect any potential deadlocks.
-        while (broadcast_test.count != num_threads) {
-            try broadcast_test.completed.timedWait(
-                &broadcast_test.mutex,
-                1 * std.time.ns_per_s,
-            );
+    var num_runs: usize = 1;
+    const num_threads = 3;
+    const timeoutDelay = 10 * std.time.ns_per_ms;
+
+    while (num_runs > 0) : (num_runs -= 1) {
+
+        // Start multiple runner threads, wait for them to start and send the signal
+        // then. Expect that one thread wake up and all other times out.
+        //
+        // Test depends on delay in timedWait! If too small all threads can timeout
+        // before any one gets wake up.
+
+        const Runner = struct {
+            mutex: Mutex = .{},
+            cond: Condition = .{},
+            completed: Condition = .{},
+            count: usize = 0,
+            threads: [num_threads]std.Thread = undefined,
+            wakeups: usize = 0,
+            timeouts: usize = 0,
+
+            fn run(self: *@This()) void {
+                self.mutex.lock();
+                defer self.mutex.unlock();
+
+                // The last started thread tells the main test thread it's completed.
+                self.count += 1;
+                if (self.count == num_threads) {
+                    self.completed.signal();
+                }
+
+                self.cond.timedWait(&self.mutex, timeoutDelay) catch {
+                    self.timeouts += 1;
+                    return;
+                };
+                self.wakeups += 1;
+            }
+        };
+
+        // Start threads
+        var runner = Runner{};
+        for (runner.threads) |*t| {
+            t.* = try std.Thread.spawn(.{}, Runner.run, .{&runner});
         }
 
-        // Signal thread 1 to wake up
-        broadcast_test.thread_id_to_wake = 1;
-        broadcast_test.cond.broadcast();
-    }
+        {
+            runner.mutex.lock();
+            defer runner.mutex.unlock();
 
-    for (broadcast_test.threads) |t| {
-        t.join();
+            // Wait for all the threads to spawn.
+            // timedWait() to detect any potential deadlocks.
+            while (runner.count != num_threads) {
+                try runner.completed.timedWait(&runner.mutex, 1 * std.time.ns_per_s);
+            }
+            // Signal one thread, the others should get timeout.
+            runner.cond.signal();
+        }
+
+        for (runner.threads) |t| {
+            t.join();
+        }
+
+        // Expect that only one got singal
+        try std.testing.expectEqual(runner.wakeups, 1);
+        try std.testing.expectEqual(runner.timeouts, num_threads - 1);
     }
-    //std.debug.print("wakeups {d}\n", .{broadcast_test.wakeups});
 }