Commit 531d5b213f

cryptocode <cryptocode@zolo.io>
2022-04-08 20:26:56
std: add Thread.Condition.timedWait (#11352)
* std: add Thread.Condition.timedWait I needed the equivalent of `std::condition_variable::wait_for`, but it's missing in std. This PR adds an implementation, following the status quo of using std.os.CLOCK.REALTIME in the pthread case (i.e. Futex) A follow-up patch moving futex/condition stuff to monotonic clocks where available seems like a good idea. This would involve conditionally exposing more functions and constants through std.c and std.os. For instance, Chromium picks `pthread_cond_timedwait_relative_np` on macOS and `clock_gettime(CLOCK_MONOTONIC...)` on BSD's. Tested on Windows 11, macOS 12.2.1 and Linux (with/without libc) * Sleep in the single threaded case, handle timeout overflow in the Windows case and address a race condition in the AtomicCondition case.
1 parent 6ae8fe1
Changed files (1)
lib
std
lib/std/Thread/Condition.zig
@@ -17,6 +17,10 @@ pub fn wait(cond: *Condition, mutex: *Mutex) void {
     cond.impl.wait(mutex);
 }
 
+pub fn timedWait(cond: *Condition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void {
+    try cond.impl.timedWait(mutex, timeout_ns);
+}
+
 pub fn signal(cond: *Condition) void {
     cond.impl.signal();
 }
@@ -41,6 +45,14 @@ pub const SingleThreadedCondition = struct {
         unreachable; // deadlock detected
     }
 
+    pub fn timedWait(cond: *SingleThreadedCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void {
+        _ = cond;
+        _ = mutex;
+        _ = timeout_ns;
+        std.time.sleep(timeout_ns);
+        return error.TimedOut;
+    }
+
     pub fn signal(cond: *SingleThreadedCondition) void {
         _ = cond;
     }
@@ -63,6 +75,25 @@ pub const WindowsCondition = struct {
         assert(rc != windows.FALSE);
     }
 
+    pub fn timedWait(cond: *WindowsCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void {
+        var timeout_checked = std.math.cast(windows.DWORD, timeout_ns / std.time.ns_per_ms) catch overflow: {
+            break :overflow std.math.maxInt(windows.DWORD);
+        };
+
+        // Handle the case where timeout is INFINITE, otherwise SleepConditionVariableSRW's time-out never elapses
+        const timeout_overflowed = timeout_checked == windows.INFINITE;
+        timeout_checked -= @boolToInt(timeout_overflowed);
+
+        const rc = windows.kernel32.SleepConditionVariableSRW(
+            &cond.cond,
+            &mutex.impl.srwlock,
+            timeout_checked,
+            @as(windows.ULONG, 0),
+        );
+        if (rc == windows.FALSE and windows.kernel32.GetLastError() == windows.Win32Error.TIMEOUT) return error.TimedOut;
+        assert(rc != windows.FALSE);
+    }
+
     pub fn signal(cond: *WindowsCondition) void {
         windows.kernel32.WakeConditionVariable(&cond.cond);
     }
@@ -80,6 +111,24 @@ pub const PthreadCondition = struct {
         assert(rc == .SUCCESS);
     }
 
+    pub fn timedWait(cond: *PthreadCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void {
+        var ts: std.os.timespec = undefined;
+        std.os.clock_gettime(std.os.CLOCK.REALTIME, &ts) catch unreachable;
+        ts.tv_sec += @intCast(@TypeOf(ts.tv_sec), timeout_ns / std.time.ns_per_s);
+        ts.tv_nsec += @intCast(@TypeOf(ts.tv_nsec), timeout_ns % std.time.ns_per_s);
+        if (ts.tv_nsec >= std.time.ns_per_s) {
+            ts.tv_sec += 1;
+            ts.tv_nsec -= std.time.ns_per_s;
+        }
+
+        const rc = std.c.pthread_cond_timedwait(&cond.cond, &mutex.impl.pthread_mutex, &ts);
+        return switch (rc) {
+            .SUCCESS => {},
+            .TIMEDOUT => error.TimedOut,
+            else => unreachable,
+        };
+    }
+
     pub fn signal(cond: *PthreadCondition) void {
         const rc = std.c.pthread_cond_signal(&cond.cond);
         assert(rc == .SUCCESS);
@@ -100,6 +149,7 @@ pub const AtomicCondition = struct {
 
     pub const QueueItem = struct {
         futex: i32 = 0,
+        dequeued: bool = false,
 
         fn wait(cond: *@This()) void {
             while (@atomicLoad(i32, &cond.futex, .Acquire) == 0) {
@@ -122,6 +172,39 @@ pub const AtomicCondition = struct {
             }
         }
 
+        pub fn timedWait(cond: *@This(), timeout_ns: u64) error{TimedOut}!void {
+            const start_time = std.time.nanoTimestamp();
+            while (@atomicLoad(i32, &cond.futex, .Acquire) == 0) {
+                switch (builtin.os.tag) {
+                    .linux => {
+                        var ts: std.os.timespec = undefined;
+                        ts.tv_sec = @intCast(@TypeOf(ts.tv_sec), timeout_ns / std.time.ns_per_s);
+                        ts.tv_nsec = @intCast(@TypeOf(ts.tv_nsec), timeout_ns % std.time.ns_per_s);
+                        switch (linux.getErrno(linux.futex_wait(
+                            &cond.futex,
+                            linux.FUTEX.PRIVATE_FLAG | linux.FUTEX.WAIT,
+                            0,
+                            &ts,
+                        ))) {
+                            .SUCCESS => {},
+                            .INTR => {},
+                            .AGAIN => {},
+                            .TIMEDOUT => return error.TimedOut,
+                            .INVAL => {}, // possibly timeout overflow
+                            .FAULT => unreachable,
+                            else => unreachable,
+                        }
+                    },
+                    else => {
+                        if (std.time.nanoTimestamp() - start_time >= timeout_ns) {
+                            return error.TimedOut;
+                        }
+                        std.atomic.spinLoopHint();
+                    },
+                }
+            }
+        }
+
         fn notify(cond: *@This()) void {
             @atomicStore(i32, &cond.futex, 1, .Release);
 
@@ -158,6 +241,41 @@ pub const AtomicCondition = struct {
         mutex.lock();
     }
 
+    pub fn timedWait(cond: *AtomicCondition, mutex: *Mutex, timeout_ns: u64) error{TimedOut}!void {
+        var waiter = QueueList.Node{ .data = .{} };
+
+        {
+            cond.queue_mutex.lock();
+            defer cond.queue_mutex.unlock();
+
+            cond.queue_list.prepend(&waiter);
+            @atomicStore(bool, &cond.pending, true, .SeqCst);
+        }
+
+        var timed_out = false;
+        mutex.unlock();
+        defer mutex.lock();
+        waiter.data.timedWait(timeout_ns) catch |err| switch (err) {
+            error.TimedOut => {
+                defer if (!timed_out) {
+                    waiter.data.wait();
+                };
+                cond.queue_mutex.lock();
+                defer cond.queue_mutex.unlock();
+
+                if (!waiter.data.dequeued) {
+                    timed_out = true;
+                    cond.queue_list.remove(&waiter);
+                }
+            },
+            else => unreachable,
+        };
+
+        if (timed_out) {
+            return error.TimedOut;
+        }
+    }
+
     pub fn signal(cond: *AtomicCondition) void {
         if (@atomicLoad(bool, &cond.pending, .SeqCst) == false)
             return;
@@ -167,12 +285,16 @@ pub const AtomicCondition = struct {
             defer cond.queue_mutex.unlock();
 
             const maybe_waiter = cond.queue_list.popFirst();
+            if (maybe_waiter) |waiter| {
+                waiter.data.dequeued = true;
+            }
             @atomicStore(bool, &cond.pending, cond.queue_list.first != null, .SeqCst);
             break :blk maybe_waiter;
         };
 
-        if (maybe_waiter) |waiter|
+        if (maybe_waiter) |waiter| {
             waiter.data.notify();
+        }
     }
 
     pub fn broadcast(cond: *AtomicCondition) void {
@@ -186,12 +308,19 @@ pub const AtomicCondition = struct {
             defer cond.queue_mutex.unlock();
 
             const waiters = cond.queue_list;
+
+            var it = waiters.first;
+            while (it) |node| : (it = node.next) {
+                node.data.dequeued = true;
+            }
+
             cond.queue_list = .{};
             break :blk waiters;
         };
 
-        while (waiters.popFirst()) |waiter|
+        while (waiters.popFirst()) |waiter| {
             waiter.data.notify();
+        }
     }
 };
 
@@ -238,3 +367,45 @@ test "Thread.Condition" {
 
     for (threads) |t| t.join();
 }
+
+test "Thread.Condition.timedWait" {
+    if (builtin.single_threaded) {
+        return error.SkipZigTest;
+    }
+
+    var cond = Condition{};
+    var mut = Mutex{};
+
+    // Expect a timeout, as the condition variable is never signaled
+    {
+        mut.lock();
+        defer mut.unlock();
+        try testing.expectError(error.TimedOut, cond.timedWait(&mut, 10 * std.time.ns_per_ms));
+    }
+
+    // Expect a signal before timeout
+    {
+        const TestContext = struct {
+            cond: *Condition,
+            mutex: *Mutex,
+            n: *u32,
+            fn worker(ctx: *@This()) void {
+                ctx.mutex.lock();
+                defer ctx.mutex.unlock();
+                ctx.n.* = 1;
+                ctx.cond.signal();
+            }
+        };
+
+        var n: u32 = 0;
+
+        var ctx = TestContext{ .cond = &cond, .mutex = &mut, .n = &n };
+        mut.lock();
+        var thread = try std.Thread.spawn(.{}, TestContext.worker, .{&ctx});
+        // Looped check to handle spurious wakeups
+        while (n != 1) try cond.timedWait(&mut, 500 * std.time.ns_per_ms);
+        mut.unlock();
+        try testing.expect(n == 1);
+        thread.join();
+    }
+}