Commit b8fabb3426

kprotty <kbutcher6200@gmail.com>
2019-12-22 22:24:17
ResetEvent: broadcast by default
1 parent 28dbdba
Changed files (3)
lib/std/c.zig
@@ -220,6 +220,7 @@ pub const PTHREAD_COND_INITIALIZER = pthread_cond_t{};
 pub extern "c" fn pthread_cond_wait(noalias cond: *pthread_cond_t, noalias mutex: *pthread_mutex_t) c_int;
 pub extern "c" fn pthread_cond_timedwait(noalias cond: *pthread_cond_t, noalias mutex: *pthread_mutex_t, noalias abstime: *const timespec) c_int;
 pub extern "c" fn pthread_cond_signal(cond: *pthread_cond_t) c_int;
+pub extern "c" fn pthread_cond_broadcast(cond: *pthread_cond_t) c_int;
 pub extern "c" fn pthread_cond_destroy(cond: *pthread_cond_t) c_int;
 
 pub const pthread_t = *@OpaqueType();
lib/std/mutex.zig
@@ -1,6 +1,8 @@
 const std = @import("std.zig");
 const builtin = @import("builtin");
 const os = std.os;
+const assert = std.debug.assert;
+const windows = os.windows;
 const testing = std.testing;
 const SpinLock = std.SpinLock;
 const ResetEvent = std.ResetEvent;
@@ -73,20 +75,33 @@ else if (builtin.os == .windows)
             return self.tryAcquire() orelse self.acquireSlow();
         }
 
+        fn acquireSpinning(self: *Mutex) Held {
+            @setCold(true);
+            while (true) : (SpinLock.yield()) {
+                return self.tryAcquire() orelse continue;
+            }
+        }
+
         fn acquireSlow(self: *Mutex) Held {
+            // try to use NT keyed events for blocking, falling back to spinlock if unavailable
             @setCold(true);
+            const handle = ResetEvent.OsEvent.Futex.getEventHandle() orelse return self.acquireSpinning();
+            const key = @ptrCast(*const c_void, &self.waiters);
+
             while (true) : (SpinLock.loopHint(1)) {
                 const waiters = @atomicLoad(u32, &self.waiters, .Monotonic);
 
                 // try and take lock if unlocked
                 if ((waiters & 1) == 0) {
-                    if (@atomicRmw(u8, &self.locked, .Xchg, 1, .Acquire) == 0)
+                    if (@atomicRmw(u8, &self.locked, .Xchg, 1, .Acquire) == 0) {
                         return Held{ .mutex = self };
+                    }
 
                 // otherwise, try and update the waiting count.
                 // then unset the WAKE bit so that another unlocker can wake up a thread.
                 } else if (@cmpxchgWeak(u32, &self.waiters, waiters, (waiters + WAIT) | 1, .Monotonic, .Monotonic) == null) {
-                    ResetEvent.OsEvent.Futex.wait(@ptrCast(*i32, &self.waiters), undefined, null) catch unreachable;
+                    const rc = windows.ntdll.NtWaitForKeyedEvent(handle, key, windows.FALSE, null);
+                    assert(rc == 0);
                     _ = @atomicRmw(u32, &self.waiters, .Sub, WAKE, .Monotonic);
                 }
             }
@@ -98,6 +113,8 @@ else if (builtin.os == .windows)
             pub fn release(self: Held) void {
                 // unlock without a rmw/cmpxchg instruction
                 @atomicStore(u8, @ptrCast(*u8, &self.mutex.locked), 0, .Release);
+                const handle = ResetEvent.OsEvent.Futex.getEventHandle() orelse return;
+                const key = @ptrCast(*const c_void, &self.mutex.waiters);
 
                 while (true) : (SpinLock.loopHint(1)) {
                     const waiters = @atomicLoad(u32, &self.mutex.waiters, .Monotonic);
@@ -110,8 +127,11 @@ else if (builtin.os == .windows)
                     if (waiters & WAKE != 0) return;
 
                     // try to decrease the waiter count & set the WAKE bit meaning a thread is waking up
-                    if (@cmpxchgWeak(u32, &self.mutex.waiters, waiters, waiters - WAIT + WAKE, .Release, .Monotonic) == null)
-                        return ResetEvent.OsEvent.Futex.wake(@ptrCast(*i32, &self.mutex.waiters));
+                    if (@cmpxchgWeak(u32, &self.mutex.waiters, waiters, waiters - WAIT + WAKE, .Release, .Monotonic) == null) {
+                        const rc = windows.ntdll.NtReleaseKeyedEvent(handle, key, windows.FALSE, null);
+                        assert(rc == 0);
+                        return;   
+                    }
                 }
             }
         };
lib/std/reset_event.zig
@@ -36,7 +36,7 @@ pub const ResetEvent = struct {
     }
 
     /// Sets the event if not already set and
-    /// wakes up at least one thread waiting the event.
+    /// wakes up all the threads waiting on the event.
     pub fn set(self: *ResetEvent) void {
         return self.os_event.set();
     }
@@ -135,7 +135,7 @@ const PosixEvent = struct {
 
         if (!self.is_set) {
             self.is_set = true;
-            assert(c.pthread_cond_signal(&self.cond) == 0);
+            assert(c.pthread_cond_broadcast(&self.cond) == 0);
         }
     }
 
@@ -181,40 +181,39 @@ const PosixEvent = struct {
 };
 
 const AtomicEvent = struct {
-    state: State,
+    waiters: u32,
 
-    const State = enum(i32) {
-        Empty,
-        Waiting,
-        Signaled,
-    };
+    const WAKE = 1 << 0;
+    const WAIT = 1 << 1;
 
     fn init() AtomicEvent {
-        return AtomicEvent{ .state = .Empty };
+        return AtomicEvent{ .waiters = 0 };
     }
 
     fn deinit(self: *AtomicEvent) void {
         self.* = undefined;
     }
 
-    fn isSet(self: *AtomicEvent) bool {
-        return @atomicLoad(State, &self.state, .Acquire) == .Signaled;
+    fn isSet(self: *const AtomicEvent) bool {
+        return @atomicLoad(u32, &self.waiters, .Acquire) == WAKE;
     }
 
     fn reset(self: *AtomicEvent) void {
-        @atomicStore(State, &self.state, .Empty, .Monotonic);
+        @atomicStore(u32, &self.waiters, 0, .Monotonic);
     }
 
     fn set(self: *AtomicEvent) void {
-        if (@atomicRmw(State, &self.state, .Xchg, .Signaled, .Release) == .Waiting)
-            Futex.wake(@ptrCast(*i32, &self.state));
+        const waiters = @atomicRmw(u32, &self.waiters, .Xchg, WAKE, .Release);
+        if (waiters >= WAIT) {
+            return Futex.wake(&self.waiters, waiters >> 1);
+        }
     }
 
     fn wait(self: *AtomicEvent, timeout: ?u64) !void {
-        var state = @atomicLoad(State, &self.state, .Monotonic);
-        while (state == .Empty) {
-            state = @cmpxchgWeak(State, &self.state, .Empty, .Waiting, .Acquire, .Monotonic) orelse 
-                return Futex.wait(@ptrCast(*i32, &self.state), @enumToInt(State.Waiting), timeout);
+        var waiters = @atomicLoad(u32, &self.waiters, .Acquire);
+        while (waiters != WAKE) {
+            waiters = @cmpxchgWeak(u32, &self.waiters, waiters, waiters + WAIT, .Acquire, .Acquire)
+                orelse return Futex.wait(&self.waiters, timeout);
         }
     }
 
@@ -225,15 +224,15 @@ const AtomicEvent = struct {
     };
 
     const SpinFutex = struct {
-        fn wake(ptr: *i32) void {}
+        fn wake(waiters: *u32, wake_count: u32) void {}
 
-        fn wait(ptr: *i32, expected: i32, timeout: ?u64) !void {
+        fn wait(waiters: *u32, timeout: ?u64) !void {
             // TODO: handle platforms where a monotonic timer isnt available
             var timer: time.Timer = undefined;
             if (timeout != null)
                 timer = time.Timer.start() catch unreachable;
 
-            while (@atomicLoad(i32, ptr, .Acquire) == expected) {
+            while (@atomicLoad(u32, waiters, .Acquire) != WAKE) {
                 SpinLock.yield();
                 if (timeout) |timeout_ns| {
                     if (timer.read() >= timeout_ns)
@@ -244,12 +243,14 @@ const AtomicEvent = struct {
     };
 
     const LinuxFutex = struct {
-        fn wake(ptr: *i32) void {
-            const rc = linux.futex_wake(ptr, linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG, 1);
+        fn wake(waiters: *u32, wake_count: u32) void {
+            const waiting = std.math.maxInt(i32); // wake_count
+            const ptr = @ptrCast(*const i32, waiters);
+            const rc = linux.futex_wake(ptr, linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG, waiting);
             assert(linux.getErrno(rc) == 0);
         }
 
-        fn wait(ptr: *i32, expected: i32, timeout: ?u64) !void {
+        fn wait(waiters: *u32, timeout: ?u64) !void {
             var ts: linux.timespec = undefined;
             var ts_ptr: ?*linux.timespec = null;
             if (timeout) |timeout_ns| {
@@ -258,7 +259,12 @@ const AtomicEvent = struct {
                 ts.tv_nsec = @intCast(isize, timeout_ns % time.ns_per_s);
             }
 
-            while (@atomicLoad(i32, ptr, .Acquire) == expected) {
+            while (true) {
+                const waiting = @atomicLoad(u32, waiters, .Acquire);
+                if (waiting == WAKE)
+                    return;
+                const expected = @intCast(i32, waiting);
+                const ptr = @ptrCast(*const i32, waiters);
                 const rc = linux.futex_wait(ptr, linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG, expected, ts_ptr);
                 switch (linux.getErrno(rc)) {
                     0 => continue,
@@ -272,15 +278,20 @@ const AtomicEvent = struct {
     };
 
     const WindowsFutex = struct {
-        pub fn wake(ptr: *i32) void {
-            const handle = getEventHandle() orelse return SpinFutex.wake(ptr);
+        pub fn wake(waiters: *u32, wake_count: u32) void {
+            const handle = getEventHandle() orelse return SpinFutex.wake(waiters, wake_count);
             const key = @ptrCast(*const c_void, ptr);
-            const rc = windows.ntdll.NtReleaseKeyedEvent(handle, key, windows.FALSE, null);
-            assert(rc == 0);
+            
+            var waiting = wake_count;
+            while (waiting != 0) : (waiting -= 1) {
+                const rc = windows.ntdll.NtReleaseKeyedEvent(handle, key, windows.FALSE, null);
+                assert(rc == 0);
+            }
         }
 
-        pub fn wait(ptr: *i32, expected: i32, timeout: ?u64) !void {
-            const handle = getEventHandle() orelse return SpinFutex.wait(ptr, expected, timeout);
+        pub fn wait(waiters: *u32, timeout: ?u64) !void {
+            const handle = getEventHandle() orelse return SpinFutex.wait(waiters, timeout);
+            const key = @ptrCast(*const c_void, ptr);
 
             // NT uses timeouts in units of 100ns with negative value being relative
             var timeout_ptr: ?*windows.LARGE_INTEGER = null;
@@ -291,10 +302,26 @@ const AtomicEvent = struct {
             }
 
             // NtWaitForKeyedEvent doesnt have spurious wake-ups
-            const key = @ptrCast(*const c_void, ptr);
-            const rc = windows.ntdll.NtWaitForKeyedEvent(handle, key, windows.FALSE, timeout_ptr);
+            var rc = windows.ntdll.NtWaitForKeyedEvent(handle, key, windows.FALSE, timeout_ptr);
             switch (rc) {
-                windows.WAIT_TIMEOUT => return error.TimedOut,
+                windows.WAIT_TIMEOUT => {
+                    // update the wait count to signal that we're not waiting anymore.
+                    // if the .set() thread already observed that we are, perform a
+                    // matching NtWaitForKeyedEvent so that the .set() thread doesn't
+                    // deadlock trying to run NtReleaseKeyedEvent above.
+                    var waiting = @atomicLoad(u32, waiters, .Monotonic);
+                    while (true) {
+                        if (waiting == WAKE) {
+                            rc = windows.ntdll.NtWaitForKeyedEvent(handle, key, windows.FALSE, null);
+                            assert(rc == windows.WAIT_OBJECT_0);
+                            break;
+                        } else {
+                            waiting = @cmpxchgWeak(u32, waiters, waiting, waiting - WAIT, .Acquire, .Monotonic) orelse break;
+                            continue;
+                        }
+                    }
+                    return error.TimedOut;
+                },
                 windows.WAIT_OBJECT_0 => {},
                 else => unreachable,
             }