master
  1//! Condition variables are used with a Mutex to efficiently wait for an arbitrary condition to occur.
  2//! It does this by atomically unlocking the mutex, blocking the thread until notified, and finally re-locking the mutex.
  3//! Condition can be statically initialized and is at most `@sizeOf(u64)` large.
  4//!
  5//! Example:
  6//! ```
  7//! var m = Mutex{};
  8//! var c = Condition{};
  9//! var predicate = false;
 10//!
 11//! fn consumer() void {
 12//!     m.lock();
 13//!     defer m.unlock();
 14//!
 15//!     while (!predicate) {
 16//!         c.wait(&m);
 17//!     }
 18//! }
 19//!
 20//! fn producer() void {
 21//!     {
 22//!         m.lock();
 23//!         defer m.unlock();
 24//!         predicate = true;
 25//!     }
 26//!     c.signal();
 27//! }
 28//!
 29//! const thread = try std.Thread.spawn(.{}, producer, .{});
 30//! consumer();
 31//! thread.join();
 32//! ```
 33//!
 34//! Note that condition variables can only reliably unblock threads that are sequenced before them using the same Mutex.
 35//! This means that the following is allowed to deadlock:
 36//! ```
 37//! thread-1: mutex.lock()
 38//! thread-1: condition.wait(&mutex)
 39//!
 40//! thread-2: // mutex.lock() (without this, the following signal may not see the waiting thread-1)
 41//! thread-2: // mutex.unlock() (this is optional for correctness once locked above, as signal can be called while holding the mutex)
 42//! thread-2: condition.signal()
 43//! ```
 44
 45const std = @import("../std.zig");
 46const builtin = @import("builtin");
 47const Condition = @This();
 48const Mutex = std.Thread.Mutex;
 49
 50const os = std.os;
 51const assert = std.debug.assert;
 52const testing = std.testing;
 53const Futex = std.Thread.Futex;
 54
 55impl: Impl = .{},
 56
 57/// Atomically releases the Mutex, blocks the caller thread, then re-acquires the Mutex on return.
 58/// "Atomically" here refers to accesses done on the Condition after acquiring the Mutex.
 59///
 60/// The Mutex must be locked by the caller's thread when this function is called.
 61/// A Mutex can have multiple Conditions waiting with it concurrently, but not the opposite.
 62/// It is undefined behavior for multiple threads to wait ith different mutexes using the same Condition concurrently.
 63/// Once threads have finished waiting with one Mutex, the Condition can be used to wait with another Mutex.
 64///
 65/// A blocking call to wait() is unblocked from one of the following conditions:
 66/// - a spurious ("at random") wake up occurs
 67/// - a future call to `signal()` or `broadcast()` which has acquired the Mutex and is sequenced after this `wait()`.
 68///
 69/// Given wait() can be interrupted spuriously, the blocking condition should be checked continuously
 70/// irrespective of any notifications from `signal()` or `broadcast()`.
 71pub fn wait(self: *Condition, mutex: *Mutex) void {
 72    self.impl.wait(mutex, null) catch |err| switch (err) {
 73        error.Timeout => unreachable, // no timeout provided so we shouldn't have timed-out
 74    };
 75}
 76
 77/// Atomically releases the Mutex, blocks the caller thread, then re-acquires the Mutex on return.
 78/// "Atomically" here refers to accesses done on the Condition after acquiring the Mutex.
 79///
 80/// The Mutex must be locked by the caller's thread when this function is called.
 81/// A Mutex can have multiple Conditions waiting with it concurrently, but not the opposite.
 82/// It is undefined behavior for multiple threads to wait ith different mutexes using the same Condition concurrently.
 83/// Once threads have finished waiting with one Mutex, the Condition can be used to wait with another Mutex.
 84///
 85/// A blocking call to `timedWait()` is unblocked from one of the following conditions:
 86/// - a spurious ("at random") wake occurs
 87/// - the caller was blocked for around `timeout_ns` nanoseconds, in which `error.Timeout` is returned.
 88/// - a future call to `signal()` or `broadcast()` which has acquired the Mutex and is sequenced after this `timedWait()`.
 89///
 90/// Given `timedWait()` can be interrupted spuriously, the blocking condition should be checked continuously
 91/// irrespective of any notifications from `signal()` or `broadcast()`.
 92pub fn timedWait(self: *Condition, mutex: *Mutex, timeout_ns: u64) error{Timeout}!void {
 93    return self.impl.wait(mutex, timeout_ns);
 94}
 95
 96/// Unblocks at least one thread blocked in a call to `wait()` or `timedWait()` with a given Mutex.
 97/// The blocked thread must be sequenced before this call with respect to acquiring the same Mutex in order to be observable for unblocking.
 98/// `signal()` can be called with or without the relevant Mutex being acquired and have no "effect" if there's no observable blocked threads.
 99pub fn signal(self: *Condition) void {
100    self.impl.wake(.one);
101}
102
103/// Unblocks all threads currently blocked in a call to `wait()` or `timedWait()` with a given Mutex.
104/// The blocked threads must be sequenced before this call with respect to acquiring the same Mutex in order to be observable for unblocking.
105/// `broadcast()` can be called with or without the relevant Mutex being acquired and have no "effect" if there's no observable blocked threads.
106pub fn broadcast(self: *Condition) void {
107    self.impl.wake(.all);
108}
109
110const Impl = if (builtin.single_threaded)
111    SingleThreadedImpl
112else if (builtin.os.tag == .windows)
113    WindowsImpl
114else
115    FutexImpl;
116
117const Notify = enum {
118    one, // wake up only one thread
119    all, // wake up all threads
120};
121
122const SingleThreadedImpl = struct {
123    fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
124        _ = self;
125        _ = mutex;
126        // There are no other threads to wake us up.
127        // So if we wait without a timeout we would never wake up.
128        assert(timeout != null); // Deadlock detected.
129        return error.Timeout;
130    }
131
132    fn wake(self: *Impl, comptime notify: Notify) void {
133        // There are no other threads to wake up.
134        _ = self;
135        _ = notify;
136    }
137};
138
139const WindowsImpl = struct {
140    condition: os.windows.CONDITION_VARIABLE = .{},
141
142    fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
143        var timeout_overflowed = false;
144        var timeout_ms: os.windows.DWORD = os.windows.INFINITE;
145
146        if (timeout) |timeout_ns| {
147            // Round the nanoseconds to the nearest millisecond,
148            // then saturating cast it to windows DWORD for use in kernel32 call.
149            const ms = (timeout_ns +| (std.time.ns_per_ms / 2)) / std.time.ns_per_ms;
150            timeout_ms = std.math.cast(os.windows.DWORD, ms) orelse std.math.maxInt(os.windows.DWORD);
151
152            // Track if the timeout overflowed into INFINITE and make sure not to wait forever.
153            if (timeout_ms == os.windows.INFINITE) {
154                timeout_overflowed = true;
155                timeout_ms -= 1;
156            }
157        }
158
159        if (builtin.mode == .Debug) {
160            // The internal state of the DebugMutex needs to be handled here as well.
161            mutex.impl.locking_thread.store(0, .unordered);
162        }
163        const rc = os.windows.kernel32.SleepConditionVariableSRW(
164            &self.condition,
165            if (builtin.mode == .Debug) &mutex.impl.impl.srwlock else &mutex.impl.srwlock,
166            timeout_ms,
167            0, // the srwlock was assumed to acquired in exclusive mode not shared
168        );
169        if (builtin.mode == .Debug) {
170            // The internal state of the DebugMutex needs to be handled here as well.
171            mutex.impl.locking_thread.store(std.Thread.getCurrentId(), .unordered);
172        }
173
174        // Return error.Timeout if we know the timeout elapsed correctly.
175        if (rc == os.windows.FALSE) {
176            assert(os.windows.GetLastError() == .TIMEOUT);
177            if (!timeout_overflowed) return error.Timeout;
178        }
179    }
180
181    fn wake(self: *Impl, comptime notify: Notify) void {
182        switch (notify) {
183            .one => os.windows.ntdll.RtlWakeConditionVariable(&self.condition),
184            .all => os.windows.ntdll.RtlWakeAllConditionVariable(&self.condition),
185        }
186    }
187};
188
189const FutexImpl = struct {
190    state: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
191    epoch: std.atomic.Value(u32) = std.atomic.Value(u32).init(0),
192
193    const one_waiter = 1;
194    const waiter_mask = 0xffff;
195
196    const one_signal = 1 << 16;
197    const signal_mask = 0xffff << 16;
198
199    fn wait(self: *Impl, mutex: *Mutex, timeout: ?u64) error{Timeout}!void {
200        // Observe the epoch, then check the state again to see if we should wake up.
201        // The epoch must be observed before we check the state or we could potentially miss a wake() and deadlock:
202        //
203        // - T1: s = LOAD(&state)
204        // - T2: UPDATE(&s, signal)
205        // - T2: UPDATE(&epoch, 1) + FUTEX_WAKE(&epoch)
206        // - T1: e = LOAD(&epoch) (was reordered after the state load)
207        // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed the state update + the epoch change)
208        //
209        // Acquire barrier to ensure the epoch load happens before the state load.
210        var epoch = self.epoch.load(.acquire);
211        var state = self.state.fetchAdd(one_waiter, .monotonic);
212        assert(state & waiter_mask != waiter_mask);
213        state += one_waiter;
214
215        mutex.unlock();
216        defer mutex.lock();
217
218        var futex_deadline = Futex.Deadline.init(timeout);
219
220        while (true) {
221            futex_deadline.wait(&self.epoch, epoch) catch |err| switch (err) {
222                // On timeout, we must decrement the waiter we added above.
223                error.Timeout => {
224                    while (true) {
225                        // If there's a signal when we're timing out, consume it and report being woken up instead.
226                        // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
227                        while (state & signal_mask != 0) {
228                            const new_state = state - one_waiter - one_signal;
229                            state = self.state.cmpxchgWeak(state, new_state, .acquire, .monotonic) orelse return;
230                        }
231
232                        // Remove the waiter we added and officially return timed out.
233                        const new_state = state - one_waiter;
234                        state = self.state.cmpxchgWeak(state, new_state, .monotonic, .monotonic) orelse return err;
235                    }
236                },
237            };
238
239            epoch = self.epoch.load(.acquire);
240            state = self.state.load(.monotonic);
241
242            // Try to wake up by consuming a signal and decremented the waiter we added previously.
243            // Acquire barrier ensures code before the wake() which added the signal happens before we decrement it and return.
244            while (state & signal_mask != 0) {
245                const new_state = state - one_waiter - one_signal;
246                state = self.state.cmpxchgWeak(state, new_state, .acquire, .monotonic) orelse return;
247            }
248        }
249    }
250
251    fn wake(self: *Impl, comptime notify: Notify) void {
252        var state = self.state.load(.monotonic);
253        while (true) {
254            const waiters = (state & waiter_mask) / one_waiter;
255            const signals = (state & signal_mask) / one_signal;
256
257            // Reserves which waiters to wake up by incrementing the signals count.
258            // Therefore, the signals count is always less than or equal to the waiters count.
259            // We don't need to Futex.wake if there's nothing to wake up or if other wake() threads have reserved to wake up the current waiters.
260            const wakeable = waiters - signals;
261            if (wakeable == 0) {
262                return;
263            }
264
265            const to_wake = switch (notify) {
266                .one => 1,
267                .all => wakeable,
268            };
269
270            // Reserve the amount of waiters to wake by incrementing the signals count.
271            // Release barrier ensures code before the wake() happens before the signal it posted and consumed by the wait() threads.
272            const new_state = state + (one_signal * to_wake);
273            state = self.state.cmpxchgWeak(state, new_state, .release, .monotonic) orelse {
274                // Wake up the waiting threads we reserved above by changing the epoch value.
275                // NOTE: a waiting thread could miss a wake up if *exactly* ((1<<32)-1) wake()s happen between it observing the epoch and sleeping on it.
276                // This is very unlikely due to how many precise amount of Futex.wake() calls that would be between the waiting thread's potential preemption.
277                //
278                // Release barrier ensures the signal being added to the state happens before the epoch is changed.
279                // If not, the waiting thread could potentially deadlock from missing both the state and epoch change:
280                //
281                // - T2: UPDATE(&epoch, 1) (reordered before the state change)
282                // - T1: e = LOAD(&epoch)
283                // - T1: s = LOAD(&state)
284                // - T2: UPDATE(&state, signal) + FUTEX_WAKE(&epoch)
285                // - T1: s & signals == 0 -> FUTEX_WAIT(&epoch, e) (missed both epoch change and state change)
286                _ = self.epoch.fetchAdd(1, .release);
287                Futex.wake(&self.epoch, to_wake);
288                return;
289            };
290        }
291    }
292};
293
294test "smoke test" {
295    var mutex = Mutex{};
296    var cond = Condition{};
297
298    // Try to wake outside the mutex
299    defer cond.signal();
300    defer cond.broadcast();
301
302    mutex.lock();
303    defer mutex.unlock();
304
305    // Try to wait with a timeout (should not deadlock)
306    try testing.expectError(error.Timeout, cond.timedWait(&mutex, 0));
307    try testing.expectError(error.Timeout, cond.timedWait(&mutex, std.time.ns_per_ms));
308
309    // Try to wake inside the mutex.
310    cond.signal();
311    cond.broadcast();
312}
313
314// Inspired from: https://github.com/Amanieu/parking_lot/pull/129
315test "wait and signal" {
316    // This test requires spawning threads
317    if (builtin.single_threaded) {
318        return error.SkipZigTest;
319    }
320
321    const io = testing.io;
322
323    const num_threads = 4;
324
325    const MultiWait = struct {
326        mutex: Mutex = .{},
327        cond: Condition = .{},
328        threads: [num_threads]std.Thread = undefined,
329        spawn_count: std.math.IntFittingRange(0, num_threads) = 0,
330
331        fn run(self: *@This()) void {
332            self.mutex.lock();
333            defer self.mutex.unlock();
334            self.spawn_count += 1;
335
336            self.cond.wait(&self.mutex);
337            self.cond.timedWait(&self.mutex, std.time.ns_per_ms) catch {};
338            self.cond.signal();
339        }
340    };
341
342    var multi_wait = MultiWait{};
343    for (&multi_wait.threads) |*t| {
344        t.* = try std.Thread.spawn(.{}, MultiWait.run, .{&multi_wait});
345    }
346
347    while (true) {
348        try std.Io.Clock.Duration.sleep(.{ .clock = .awake, .raw = .fromMilliseconds(100) }, io);
349
350        multi_wait.mutex.lock();
351        defer multi_wait.mutex.unlock();
352        // Make sure all of the threads have finished spawning to avoid a deadlock.
353        if (multi_wait.spawn_count == num_threads) break;
354    }
355
356    multi_wait.cond.signal();
357    for (multi_wait.threads) |t| {
358        t.join();
359    }
360}
361
362test signal {
363    // This test requires spawning threads
364    if (builtin.single_threaded) {
365        return error.SkipZigTest;
366    }
367
368    const io = testing.io;
369
370    const num_threads = 4;
371
372    const SignalTest = struct {
373        mutex: Mutex = .{},
374        cond: Condition = .{},
375        notified: bool = false,
376        threads: [num_threads]std.Thread = undefined,
377        spawn_count: std.math.IntFittingRange(0, num_threads) = 0,
378
379        fn run(self: *@This()) void {
380            self.mutex.lock();
381            defer self.mutex.unlock();
382            self.spawn_count += 1;
383
384            // Use timedWait() a few times before using wait()
385            // to test multiple threads timing out frequently.
386            var i: usize = 0;
387            while (!self.notified) : (i +%= 1) {
388                if (i < 5) {
389                    self.cond.timedWait(&self.mutex, 1) catch {};
390                } else {
391                    self.cond.wait(&self.mutex);
392                }
393            }
394
395            // Once we received the signal, notify another thread (inside the lock).
396            assert(self.notified);
397            self.cond.signal();
398        }
399    };
400
401    var signal_test = SignalTest{};
402    for (&signal_test.threads) |*t| {
403        t.* = try std.Thread.spawn(.{}, SignalTest.run, .{&signal_test});
404    }
405
406    while (true) {
407        try std.Io.Clock.Duration.sleep(.{ .clock = .awake, .raw = .fromMilliseconds(10) }, io);
408
409        signal_test.mutex.lock();
410        defer signal_test.mutex.unlock();
411        // Make sure at least one thread has finished spawning to avoid testing nothing.
412        if (signal_test.spawn_count > 0) break;
413    }
414
415    {
416        // Wake up one of them (outside the lock) after setting notified=true.
417        defer signal_test.cond.signal();
418
419        signal_test.mutex.lock();
420        defer signal_test.mutex.unlock();
421
422        try testing.expect(!signal_test.notified);
423        signal_test.notified = true;
424    }
425
426    for (signal_test.threads) |t| {
427        t.join();
428    }
429}
430
431test "multi signal" {
432    // This test requires spawning threads
433    if (builtin.single_threaded) {
434        return error.SkipZigTest;
435    }
436
437    const num_threads = 4;
438    const num_iterations = 4;
439
440    const Paddle = struct {
441        mutex: Mutex = .{},
442        cond: Condition = .{},
443        value: u32 = 0,
444
445        fn hit(self: *@This()) void {
446            defer self.cond.signal();
447
448            self.mutex.lock();
449            defer self.mutex.unlock();
450
451            self.value += 1;
452        }
453
454        fn run(self: *@This(), hit_to: *@This()) !void {
455            self.mutex.lock();
456            defer self.mutex.unlock();
457
458            var current: u32 = 0;
459            while (current < num_iterations) : (current += 1) {
460                // Wait for the value to change from hit()
461                while (self.value == current) {
462                    self.cond.wait(&self.mutex);
463                }
464
465                // hit the next paddle
466                try testing.expectEqual(self.value, current + 1);
467                hit_to.hit();
468            }
469        }
470    };
471
472    var paddles = [_]Paddle{.{}} ** num_threads;
473    var threads = [_]std.Thread{undefined} ** num_threads;
474
475    // Create a circle of paddles which hit each other
476    for (&threads, 0..) |*t, i| {
477        const paddle = &paddles[i];
478        const hit_to = &paddles[(i + 1) % paddles.len];
479        t.* = try std.Thread.spawn(.{}, Paddle.run, .{ paddle, hit_to });
480    }
481
482    // Hit the first paddle and wait for them all to complete by hitting each other for num_iterations.
483    paddles[0].hit();
484    for (threads) |t| t.join();
485
486    // The first paddle will be hit one last time by the last paddle.
487    for (paddles, 0..) |p, i| {
488        const expected = @as(u32, num_iterations) + @intFromBool(i == 0);
489        try testing.expectEqual(p.value, expected);
490    }
491}
492
493test broadcast {
494    // This test requires spawning threads
495    if (builtin.single_threaded) {
496        return error.SkipZigTest;
497    }
498
499    const num_threads = 10;
500
501    const BroadcastTest = struct {
502        mutex: Mutex = .{},
503        cond: Condition = .{},
504        completed: Condition = .{},
505        count: usize = 0,
506        threads: [num_threads]std.Thread = undefined,
507
508        fn run(self: *@This()) void {
509            self.mutex.lock();
510            defer self.mutex.unlock();
511
512            // The last broadcast thread to start tells the main test thread it's completed.
513            self.count += 1;
514            if (self.count == num_threads) {
515                self.completed.signal();
516            }
517
518            // Waits for the count to reach zero after the main test thread observes it at num_threads.
519            // Tries to use timedWait() a bit before falling back to wait() to test multiple threads timing out.
520            var i: usize = 0;
521            while (self.count != 0) : (i +%= 1) {
522                if (i < 10) {
523                    self.cond.timedWait(&self.mutex, 1) catch {};
524                } else {
525                    self.cond.wait(&self.mutex);
526                }
527            }
528        }
529    };
530
531    var broadcast_test = BroadcastTest{};
532    for (&broadcast_test.threads) |*t| {
533        t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{&broadcast_test});
534    }
535
536    {
537        broadcast_test.mutex.lock();
538        defer broadcast_test.mutex.unlock();
539
540        // Wait for all the broadcast threads to spawn.
541        // timedWait() to detect any potential deadlocks.
542        while (broadcast_test.count != num_threads) {
543            broadcast_test.completed.timedWait(
544                &broadcast_test.mutex,
545                1 * std.time.ns_per_s,
546            ) catch {};
547        }
548
549        // Reset the counter and wake all the threads to exit.
550        broadcast_test.count = 0;
551        broadcast_test.cond.broadcast();
552    }
553
554    for (broadcast_test.threads) |t| {
555        t.join();
556    }
557}
558
559test "broadcasting - wake all threads" {
560    // Tests issue #12877
561    // This test requires spawning threads
562    if (builtin.single_threaded) {
563        return error.SkipZigTest;
564    }
565
566    var num_runs: usize = 1;
567    const num_threads = 10;
568
569    while (num_runs > 0) : (num_runs -= 1) {
570        const BroadcastTest = struct {
571            mutex: Mutex = .{},
572            cond: Condition = .{},
573            completed: Condition = .{},
574            count: usize = 0,
575            thread_id_to_wake: usize = 0,
576            threads: [num_threads]std.Thread = undefined,
577            wakeups: usize = 0,
578
579            fn run(self: *@This(), thread_id: usize) void {
580                self.mutex.lock();
581                defer self.mutex.unlock();
582
583                // The last broadcast thread to start tells the main test thread it's completed.
584                self.count += 1;
585                if (self.count == num_threads) {
586                    self.completed.signal();
587                }
588
589                while (self.thread_id_to_wake != thread_id) {
590                    self.cond.timedWait(&self.mutex, 1 * std.time.ns_per_s) catch {};
591                    self.wakeups += 1;
592                }
593                if (self.thread_id_to_wake <= num_threads) {
594                    // Signal next thread to wake up.
595                    self.thread_id_to_wake += 1;
596                    self.cond.broadcast();
597                }
598            }
599        };
600
601        var broadcast_test = BroadcastTest{};
602        var thread_id: usize = 1;
603        for (&broadcast_test.threads) |*t| {
604            t.* = try std.Thread.spawn(.{}, BroadcastTest.run, .{ &broadcast_test, thread_id });
605            thread_id += 1;
606        }
607
608        {
609            broadcast_test.mutex.lock();
610            defer broadcast_test.mutex.unlock();
611
612            // Wait for all the broadcast threads to spawn.
613            // timedWait() to detect any potential deadlocks.
614            while (broadcast_test.count != num_threads) {
615                broadcast_test.completed.timedWait(
616                    &broadcast_test.mutex,
617                    1 * std.time.ns_per_s,
618                ) catch {};
619            }
620
621            // Signal thread 1 to wake up
622            broadcast_test.thread_id_to_wake = 1;
623            broadcast_test.cond.broadcast();
624        }
625
626        for (broadcast_test.threads) |t| {
627            t.join();
628        }
629    }
630}