Commit f84aca36c3
Io: implement faster mutex

Jacob Young <jacobly0@users.noreply.github.com>
2025-03-31 14:06:20
1 parent a1c1d06
Changed files (3)
lib/std/Io/EventLoop.zig
@@ -10,7 +10,7 @@ const IoUring = std.os.linux.IoUring;
 /// Must be a thread-safe allocator.
 gpa: Allocator,
 mutex: std.Thread.Mutex,
-main_fiber: Fiber,
+main_fiber_buffer: [@sizeOf(Fiber) + Fiber.max_result_size]u8 align(@alignOf(Fiber)),
 threads: Thread.List,
 
 /// Empirically saw >128KB being used by the self-hosted backend to panic.
@@ -51,10 +51,12 @@ const Thread = struct {
 };
 
 const Fiber = struct {
+    required_align: void align(4),
     context: Context,
     awaiter: ?*Fiber,
     queue_next: ?*Fiber,
     cancel_thread: ?*Thread,
+    awaiting_completions: std.StaticBitSet(3),
 
     const finished: ?*Fiber = @ptrFromInt(@alignOf(Thread));
 
@@ -131,7 +133,7 @@ const Fiber = struct {
         const thread: *Thread = .current();
         std.log.debug("recyling {*}", .{fiber});
         assert(fiber.queue_next == null);
-        @memset(fiber.allocatedSlice(), undefined);
+        //@memset(fiber.allocatedSlice(), undefined); // (race)
         fiber.queue_next = thread.free_queue;
         thread.free_queue = fiber;
     }
@@ -145,10 +147,17 @@ pub fn io(el: *EventLoop) Io {
         .vtable = &.{
             .@"async" = @"async",
             .@"await" = @"await",
+            .go = go,
 
             .cancel = cancel,
             .cancelRequested = cancelRequested,
 
+            .mutexLock = mutexLock,
+            .mutexUnlock = mutexUnlock,
+
+            .conditionWait = conditionWait,
+            .conditionWake = conditionWake,
+
             .createFile = createFile,
             .openFile = openFile,
             .closeFile = closeFile,
@@ -169,18 +178,22 @@ pub fn init(el: *EventLoop, gpa: Allocator) !void {
     el.* = .{
         .gpa = gpa,
         .mutex = .{},
-        .main_fiber = .{
-            .context = undefined,
-            .awaiter = null,
-            .queue_next = null,
-            .cancel_thread = null,
-        },
+        .main_fiber_buffer = undefined,
         .threads = .{
             .allocated = @ptrCast(allocated_slice[0..threads_size]),
             .reserved = 1,
             .active = 1,
         },
     };
+    const main_fiber: *Fiber = @ptrCast(&el.main_fiber_buffer);
+    main_fiber.* = .{
+        .required_align = {},
+        .context = undefined,
+        .awaiter = null,
+        .queue_next = null,
+        .cancel_thread = null,
+        .awaiting_completions = .initEmpty(),
+    };
     const main_thread = &el.threads.allocated[0];
     Thread.self = main_thread;
     const idle_stack_end: [*]usize = @alignCast(@ptrCast(allocated_slice[idle_stack_end_offset..].ptr));
@@ -192,7 +205,7 @@ pub fn init(el: *EventLoop, gpa: Allocator) !void {
             .rbp = 0,
             .rip = @intFromPtr(&mainIdleEntry),
         },
-        .current_context = &el.main_fiber.context,
+        .current_context = &main_fiber.context,
         .ready_queue = null,
         .free_queue = null,
         .io_uring = try IoUring.init(io_uring_entries, 0),
@@ -201,53 +214,57 @@ pub fn init(el: *EventLoop, gpa: Allocator) !void {
     };
     errdefer main_thread.io_uring.deinit();
     std.log.debug("created main idle {*}", .{&main_thread.idle_context});
-    std.log.debug("created main {*}", .{&el.main_fiber});
+    std.log.debug("created main {*}", .{main_fiber});
 }
 
 pub fn deinit(el: *EventLoop) void {
     const active_threads = @atomicLoad(u32, &el.threads.active, .acquire);
-    for (el.threads.allocated[0..active_threads]) |*thread|
-        assert(@atomicLoad(?*Fiber, &thread.ready_queue, .acquire) == null); // pending async
+    for (el.threads.allocated[0..active_threads]) |*thread| {
+        const ready_fiber = @atomicLoad(?*Fiber, &thread.ready_queue, .monotonic);
+        assert(ready_fiber == null or ready_fiber == Fiber.finished); // pending async
+    }
     el.yield(null, .exit);
+    const allocated_ptr: [*]align(@alignOf(Thread)) u8 = @alignCast(@ptrCast(el.threads.allocated.ptr));
+    const idle_stack_end_offset = std.mem.alignForward(usize, el.threads.allocated.len * @sizeOf(Thread) + idle_stack_size, std.heap.page_size_max);
+    for (el.threads.allocated[1..active_threads]) |*thread| thread.thread.join();
     for (el.threads.allocated[0..active_threads]) |*thread| while (thread.free_queue) |free_fiber| {
         thread.free_queue = free_fiber.queue_next;
         free_fiber.queue_next = null;
         el.gpa.free(free_fiber.allocatedSlice());
     };
-    const allocated_ptr: [*]align(@alignOf(Thread)) u8 = @alignCast(@ptrCast(el.threads.allocated.ptr));
-    const idle_stack_end_offset = std.mem.alignForward(usize, el.threads.allocated.len * @sizeOf(Thread) + idle_stack_size, std.heap.page_size_max);
-    for (el.threads.allocated[1..active_threads]) |thread| thread.thread.join();
     el.gpa.free(allocated_ptr[0..idle_stack_end_offset]);
     el.* = undefined;
 }
 
+fn findReadyFiber(el: *EventLoop, thread: *Thread) ?*Fiber {
+    if (@atomicRmw(?*Fiber, &thread.ready_queue, .Xchg, Fiber.finished, .acquire)) |ready_fiber| {
+        @atomicStore(?*Fiber, &thread.ready_queue, ready_fiber.queue_next, .release);
+        ready_fiber.queue_next = null;
+        return ready_fiber;
+    }
+    const active_threads = @atomicLoad(u32, &el.threads.active, .acquire);
+    for (0..@min(max_steal_ready_search, active_threads)) |_| {
+        defer thread.steal_ready_search_index += 1;
+        if (thread.steal_ready_search_index == active_threads) thread.steal_ready_search_index = 0;
+        const steal_ready_search_thread = &el.threads.allocated[0..active_threads][thread.steal_ready_search_index];
+        if (steal_ready_search_thread == thread) continue;
+        const ready_fiber = @atomicRmw(?*Fiber, &steal_ready_search_thread.ready_queue, .And, Fiber.finished, .acquire) orelse continue;
+        if (ready_fiber == Fiber.finished) continue;
+        @atomicStore(?*Fiber, &thread.ready_queue, ready_fiber.queue_next, .release);
+        ready_fiber.queue_next = null;
+        return ready_fiber;
+    }
+    // couldn't find anything to do, so we are now open for business
+    @atomicStore(?*Fiber, &thread.ready_queue, null, .monotonic);
+    return null;
+}
+
 fn yield(el: *EventLoop, maybe_ready_fiber: ?*Fiber, pending_task: SwitchMessage.PendingTask) void {
     const thread: *Thread = .current();
-    const ready_context: *Context = if (maybe_ready_fiber) |ready_fiber|
+    const ready_context = if (maybe_ready_fiber orelse el.findReadyFiber(thread)) |ready_fiber|
         &ready_fiber.context
-    else if (thread.ready_queue) |ready_fiber| ready_context: {
-        thread.ready_queue = ready_fiber.queue_next;
-        ready_fiber.queue_next = null;
-        break :ready_context &ready_fiber.context;
-    } else ready_context: {
-        const ready_threads = @atomicLoad(u32, &el.threads.active, .acquire);
-        break :ready_context for (0..max_steal_ready_search) |_| {
-            defer thread.steal_ready_search_index += 1;
-            if (thread.steal_ready_search_index == ready_threads) thread.steal_ready_search_index = 0;
-            const steal_ready_search_thread = &el.threads.allocated[thread.steal_ready_search_index];
-            if (steal_ready_search_thread == thread) continue;
-            const ready_fiber = @atomicLoad(?*Fiber, &steal_ready_search_thread.ready_queue, .acquire) orelse continue;
-            if (@cmpxchgWeak(
-                ?*Fiber,
-                &steal_ready_search_thread.ready_queue,
-                ready_fiber,
-                @atomicLoad(?*Fiber, &ready_fiber.queue_next, .acquire),
-                .acq_rel,
-                .monotonic,
-            )) |_| continue;
-            break &ready_fiber.context;
-        } else &thread.idle_context;
-    };
+    else
+        &thread.idle_context;
     const message: SwitchMessage = .{
         .contexts = .{
             .prev = thread.current_context,
@@ -270,10 +287,10 @@ fn schedule(el: *EventLoop, thread: *Thread, ready_queue: Fiber.Queue) void {
     }
     // shared fields of previous `Thread` must be initialized before later ones are marked as active
     const new_thread_index = @atomicLoad(u32, &el.threads.active, .acquire);
-    for (0..max_idle_search) |_| {
+    for (0..@min(max_idle_search, new_thread_index)) |_| {
         defer thread.idle_search_index += 1;
         if (thread.idle_search_index == new_thread_index) thread.idle_search_index = 0;
-        const idle_search_thread = &el.threads.allocated[thread.idle_search_index];
+        const idle_search_thread = &el.threads.allocated[0..new_thread_index][thread.idle_search_index];
         if (idle_search_thread == thread) continue;
         if (@cmpxchgWeak(
             ?*Fiber,
@@ -325,8 +342,8 @@ fn schedule(el: *EventLoop, thread: *Thread, ready_queue: Fiber.Queue) void {
                 std.log.warn("unable to create worker thread due to io_uring init failure: {s}", .{@errorName(err)});
                 break :spawn_thread;
             },
-            .idle_search_index = next_thread_index,
-            .steal_ready_search_index = next_thread_index,
+            .idle_search_index = 0,
+            .steal_ready_search_index = 0,
         };
         new_thread.thread = std.Thread.spawn(.{
             .stack_size = idle_stack_size,
@@ -357,7 +374,7 @@ fn mainIdle(el: *EventLoop, message: *const SwitchMessage) callconv(.withStackAl
     message.handle(el);
     const thread: *Thread = &el.threads.allocated[0];
     el.idle(thread);
-    el.yield(&el.main_fiber, .nothing);
+    el.yield(@ptrCast(&el.main_fiber_buffer), .nothing);
     unreachable; // switched to dead fiber
 }
 
@@ -384,8 +401,10 @@ const Completion = struct {
 fn idle(el: *EventLoop, thread: *Thread) void {
     var maybe_ready_fiber: ?*Fiber = null;
     while (true) {
-        el.yield(maybe_ready_fiber, .nothing);
-        maybe_ready_fiber = null;
+        while (maybe_ready_fiber orelse el.findReadyFiber(thread)) |ready_fiber| {
+            el.yield(ready_fiber, .nothing);
+            maybe_ready_fiber = null;
+        }
         _ = thread.io_uring.submit_and_wait(1) catch |err| switch (err) {
             error.SignalInterrupt => std.log.warn("submit_and_wait failed with SignalInterrupt", .{}),
             else => |e| @panic(@errorName(e)),
@@ -450,7 +469,12 @@ const SwitchMessage = struct {
 
     const PendingTask = union(enum) {
         nothing,
+        reschedule,
         register_awaiter: *?*Fiber,
+        lock_mutex: struct {
+            prev_state: Io.Mutex.State,
+            mutex: *Io.Mutex,
+        },
         exit,
     };
 
@@ -459,8 +483,14 @@ const SwitchMessage = struct {
         thread.current_context = message.contexts.ready;
         switch (message.pending_task) {
             .nothing => {},
+            .reschedule => {
+                const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
+                assert(prev_fiber.queue_next == null);
+                el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+            },
             .register_awaiter => |awaiter| {
                 const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
+                assert(prev_fiber.queue_next == null);
                 if (@atomicRmw(
                     ?*Fiber,
                     awaiter,
@@ -469,6 +499,36 @@ const SwitchMessage = struct {
                     .acq_rel,
                 ) == Fiber.finished) el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
             },
+            .lock_mutex => |lock_mutex| {
+                const prev_fiber: *Fiber = @alignCast(@fieldParentPtr("context", message.contexts.prev));
+                assert(prev_fiber.queue_next == null);
+                var prev_state = lock_mutex.prev_state;
+                while (switch (prev_state) {
+                    else => next_state: {
+                        prev_fiber.queue_next = @ptrFromInt(@intFromEnum(prev_state));
+                        break :next_state @cmpxchgWeak(
+                            Io.Mutex.State,
+                            &lock_mutex.mutex.state,
+                            prev_state,
+                            @enumFromInt(@intFromPtr(prev_fiber)),
+                            .release,
+                            .acquire,
+                        );
+                    },
+                    .unlocked => @cmpxchgWeak(
+                        Io.Mutex.State,
+                        &lock_mutex.mutex.state,
+                        .unlocked,
+                        .locked_once,
+                        .acquire,
+                        .acquire,
+                    ) orelse {
+                        prev_fiber.queue_next = null;
+                        el.schedule(thread, .{ .head = prev_fiber, .tail = prev_fiber });
+                        return;
+                    },
+                }) |next_state| prev_state = next_state;
+            },
             .exit => for (el.threads.allocated[0..@atomicLoad(u32, &el.threads.active, .acquire)]) |*each_thread| {
                 getSqe(&thread.io_uring).* = .{
                     .opcode = .MSG_RING,
@@ -590,13 +650,13 @@ fn @"async"(
         start(context.ptr, result.ptr);
         return null;
     };
-    errdefer fiber.recycle();
     std.log.debug("allocated {*}", .{fiber});
 
     const closure: *AsyncClosure = @ptrFromInt(Fiber.max_context_align.max(.of(AsyncClosure)).backward(
         @intFromPtr(fiber.allocatedEnd()) - Fiber.max_context_size,
     ) - @sizeOf(AsyncClosure));
     fiber.* = .{
+        .required_align = {},
         .context = switch (builtin.cpu.arch) {
             .x86_64 => .{
                 .rsp = @intFromPtr(closure) - @sizeOf(usize),
@@ -608,6 +668,7 @@ fn @"async"(
         .awaiter = null,
         .queue_next = null,
         .cancel_thread = null,
+        .awaiting_completions = .initEmpty(),
     };
     closure.* = .{
         .event_loop = event_loop,
@@ -634,6 +695,19 @@ fn @"await"(
     future_fiber.recycle();
 }
 
+fn go(
+    userdata: ?*anyopaque,
+    context: []const u8,
+    context_alignment: std.mem.Alignment,
+    start: *const fn (context: *const anyopaque) void,
+) void {
+    _ = userdata;
+    _ = context;
+    _ = context_alignment;
+    _ = start;
+    @panic("TODO");
+}
+
 fn cancel(
     userdata: ?*anyopaque,
     any_future: *std.Io.AnyFuture,
@@ -673,7 +747,7 @@ fn cancelRequested(userdata: ?*anyopaque) bool {
     return @atomicLoad(?*Thread, &Thread.current().currentFiber().cancel_thread, .acquire) == Thread.canceling;
 }
 
-pub fn createFile(
+fn createFile(
     userdata: ?*anyopaque,
     dir: std.fs.Dir,
     sub_path: []const u8,
@@ -775,7 +849,7 @@ pub fn createFile(
     }
 }
 
-pub fn openFile(
+fn openFile(
     userdata: ?*anyopaque,
     dir: std.fs.Dir,
     sub_path: []const u8,
@@ -883,7 +957,7 @@ pub fn openFile(
     }
 }
 
-pub fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void {
+fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
     const thread: *Thread = .current();
     const iou = &thread.io_uring;
@@ -919,7 +993,7 @@ pub fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void {
     }
 }
 
-pub fn pread(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8, offset: std.posix.off_t) Io.FilePReadError!usize {
+fn pread(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8, offset: std.posix.off_t) Io.FilePReadError!usize {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
     const thread: *Thread = .current();
     const iou = &thread.io_uring;
@@ -971,7 +1045,7 @@ pub fn pread(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8, offset: std
     }
 }
 
-pub fn pwrite(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8, offset: std.posix.off_t) Io.FilePWriteError!usize {
+fn pwrite(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8, offset: std.posix.off_t) Io.FilePWriteError!usize {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
     const thread: *Thread = .current();
     const iou = &thread.io_uring;
@@ -1027,13 +1101,13 @@ pub fn pwrite(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8, offs
     }
 }
 
-pub fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
+fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
     _ = userdata;
     const timespec = try std.posix.clock_gettime(clockid);
     return @enumFromInt(@as(i128, timespec.sec) * std.time.ns_per_s + timespec.nsec);
 }
 
-pub fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
+fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
     const el: *EventLoop = @alignCast(@ptrCast(userdata));
     const thread: *Thread = .current();
     const iou = &thread.io_uring;
@@ -1086,10 +1160,65 @@ pub fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.D
     }
 }
 
+fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
+    const el: *EventLoop = @alignCast(@ptrCast(userdata));
+    el.yield(null, .{ .lock_mutex = .{
+        .prev_state = prev_state,
+        .mutex = mutex,
+    } });
+}
+fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
+    var maybe_waiting_fiber: ?*Fiber = @ptrFromInt(@intFromEnum(prev_state));
+    while (if (maybe_waiting_fiber) |waiting_fiber| @cmpxchgWeak(
+        Io.Mutex.State,
+        &mutex.state,
+        @enumFromInt(@intFromPtr(waiting_fiber)),
+        @enumFromInt(@intFromPtr(waiting_fiber.queue_next)),
+        .release,
+        .acquire,
+    ) else @cmpxchgWeak(
+        Io.Mutex.State,
+        &mutex.state,
+        .locked_once,
+        .unlocked,
+        .release,
+        .acquire,
+    ) orelse return) |next_state| maybe_waiting_fiber = @ptrFromInt(@intFromEnum(next_state));
+    maybe_waiting_fiber.?.queue_next = null;
+    const el: *EventLoop = @alignCast(@ptrCast(userdata));
+    el.yield(maybe_waiting_fiber.?, .reschedule);
+}
+
+fn conditionWait(
+    userdata: ?*anyopaque,
+    cond: *Io.Condition,
+    mutex: *Io.Mutex,
+    timeout: ?u64,
+) Io.Condition.WaitError!void {
+    _ = userdata;
+    _ = cond;
+    _ = mutex;
+    _ = timeout;
+    @panic("TODO");
+}
+
+fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, notify: Io.Condition.Notify) void {
+    _ = userdata;
+    _ = cond;
+    _ = notify;
+    @panic("TODO");
+}
+
 fn errno(signed: i32) std.os.linux.E {
     return .init(@bitCast(@as(isize, signed)));
 }
 
 fn getSqe(iou: *IoUring) *std.os.linux.io_uring_sqe {
-    return iou.get_sqe() catch @panic("TODO: handle submission queue full");
+    while (true) return iou.get_sqe() catch {
+        _ = iou.submit_and_wait(0) catch |err| switch (err) {
+            error.SignalInterrupt => std.log.warn("submit_and_wait failed with SignalInterrupt", .{}),
+            else => |e| @panic(@errorName(e)),
+        };
+        continue;
+    };
 }
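
The lock_mutex pending task and mutexUnlock above rely on `Io.Mutex.State` doubling as a pointer: the named values occupy the low bits, and any other value is the head of an intrusive list of waiting fibers, each waiter stashing the state it displaced in its own `queue_next`. The `required_align: void align(4)` field added to `Fiber` keeps real fiber pointers from aliasing the named values. A minimal helper illustrating that encoding, assuming the layout shown in the diff (illustrative sketch, not part of the commit):

fn stateAsWaiter(state: Io.Mutex.State) ?*Fiber {
    // 0b00 => locked_once (held, no waiters)
    // 0b01 => unlocked
    // 0b10 => contended
    // >= 4 => head of the waiter list; Fiber is at least 4-byte aligned,
    //         so a real pointer can never alias the named states
    const raw = @intFromEnum(state);
    if (raw < 4) return null;
    return @ptrFromInt(raw);
}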
lib/std/Thread/Pool.zig
@@ -335,8 +335,10 @@ pub fn io(pool: *Pool) Io {
             .go = go,
             .cancel = cancel,
             .cancelRequested = cancelRequested,
+
             .mutexLock = mutexLock,
             .mutexUnlock = mutexUnlock,
+
             .conditionWait = conditionWait,
             .conditionWake = conditionWake,
 
@@ -594,53 +596,26 @@ fn checkCancel(pool: *Pool) error{Canceled}!void {
     if (cancelRequested(pool)) return error.Canceled;
 }
 
-fn mutexLock(userdata: ?*anyopaque, m: *Io.Mutex) void {
-    @branchHint(.cold);
-    const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
-    _ = pool;
-
-    // Avoid doing an atomic swap below if we already know the state is contended.
-    // An atomic swap unconditionally stores which marks the cache-line as modified unnecessarily.
-    if (m.state.load(.monotonic) == Io.Mutex.contended) {
-        std.Thread.Futex.wait(&m.state, Io.Mutex.contended);
-    }
-
-    // Try to acquire the lock while also telling the existing lock holder that there are threads waiting.
-    //
-    // Once we sleep on the Futex, we must acquire the mutex using `contended` rather than `locked`.
-    // If not, threads sleeping on the Futex wouldn't see the state change in unlock and potentially deadlock.
-    // The downside is that the last mutex unlocker will see `contended` and do an unnecessary Futex wake
-    // but this is better than having to wake all waiting threads on mutex unlock.
-    //
-    // Acquire barrier ensures grabbing the lock happens before the critical section
-    // and that the previous lock holder's critical section happens before we grab the lock.
-    while (m.state.swap(Io.Mutex.contended, .acquire) != Io.Mutex.unlocked) {
-        std.Thread.Futex.wait(&m.state, Io.Mutex.contended);
+fn mutexLock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) error{Canceled}!void {
+    _ = userdata;
+    if (prev_state == .contended) {
+        std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
-}
-
-fn mutexUnlock(userdata: ?*anyopaque, m: *Io.Mutex) void {
-    const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
-    _ = pool;
-    // Needs to also wake up a waiting thread if any.
-    //
-    // A waiting thread will acquire with `contended` instead of `locked`
-    // which ensures that it wakes up another thread on the next unlock().
-    //
-    // Release barrier ensures the critical section happens before we let go of the lock
-    // and that our critical section happens before the next lock holder grabs the lock.
-    const state = m.state.swap(Io.Mutex.unlocked, .release);
-    assert(state != Io.Mutex.unlocked);
-
-    if (state == Io.Mutex.contended) {
-        std.Thread.Futex.wake(&m.state, 1);
+    while (@atomicRmw(
+        Io.Mutex.State,
+        &mutex.state,
+        .Xchg,
+        .contended,
+        .acquire,
+    ) != .unlocked) {
+        std.Thread.Futex.wait(@ptrCast(&mutex.state), @intFromEnum(Io.Mutex.State.contended));
     }
 }
-
-fn mutexLockInternal(pool: *std.Thread.Pool, m: *Io.Mutex) void {
-    if (!m.tryLock()) {
-        @branchHint(.unlikely);
-        mutexLock(pool, m);
+fn mutexUnlock(userdata: ?*anyopaque, prev_state: Io.Mutex.State, mutex: *Io.Mutex) void {
+    _ = userdata;
+    _ = prev_state;
+    if (@atomicRmw(Io.Mutex.State, &mutex.state, .Xchg, .unlocked, .release) == .contended) {
+        std.Thread.Futex.wake(@ptrCast(&mutex.state), 1);
     }
 }
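
The thread-pool backend keeps the classic futex protocol, just re-expressed over the shared `Io.Mutex.State` enum: a contended locker swaps in `.contended` before sleeping so the eventual unlocker knows it must wake someone, and unlock wakes exactly one waiter only when it observes `.contended`. A self-contained sketch of that protocol over a plain u32 word (assumed standalone names, not the commit's code):

const std = @import("std");

const unlocked: u32 = 0b01;
const contended: u32 = 0b10;

var word: std.atomic.Value(u32) = .init(unlocked);

fn lockSlow() void {
    // Swap in `contended` so the unlocker knows a waiter may be sleeping.
    while (word.swap(contended, .acquire) != unlocked) {
        std.Thread.Futex.wait(&word, contended);
    }
}

fn unlock() void {
    // Only the contended value requires a futex wake.
    if (word.swap(unlocked, .release) == contended) {
        std.Thread.Futex.wake(&word, 1);
    }
}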
 
@@ -674,8 +649,8 @@ fn conditionWait(
     assert(state & waiter_mask != waiter_mask);
     state += one_waiter;
 
-    mutexUnlock(pool, mutex);
-    defer mutexLockInternal(pool, mutex);
+    mutex.unlock(pool.io());
+    defer mutex.lock(pool.io()) catch @panic("TODO");
 
     var futex_deadline = std.Thread.Futex.Deadline.init(timeout);
 
@@ -808,14 +783,14 @@ fn pwrite(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8, offset:
     };
 }
 
-pub fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
+fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError!Io.Timestamp {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     try pool.checkCancel();
     const timespec = try std.posix.clock_gettime(clockid);
     return @enumFromInt(@as(i128, timespec.sec) * std.time.ns_per_s + timespec.nsec);
 }
 
-pub fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
+fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     const deadline_nanoseconds: i96 = switch (deadline) {
         .nanoseconds => |nanoseconds| nanoseconds,
lib/std/Io.zig
@@ -626,8 +626,8 @@ pub const VTable = struct {
     /// Thread-safe.
     cancelRequested: *const fn (?*anyopaque) bool,
 
-    mutexLock: *const fn (?*anyopaque, mutex: *Mutex) void,
-    mutexUnlock: *const fn (?*anyopaque, mutex: *Mutex) void,
+    mutexLock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) error{Canceled}!void,
+    mutexUnlock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) void,
 
     conditionWait: *const fn (?*anyopaque, cond: *Condition, mutex: *Mutex, timeout_ns: ?u64) Condition.WaitError!void,
     conditionWake: *const fn (?*anyopaque, cond: *Condition, notify: Condition.Notify) void,
@@ -706,8 +706,63 @@ pub fn Future(Result: type) type {
     };
 }
 
-pub const Mutex = struct {
-    state: std.atomic.Value(u32) = std.atomic.Value(u32).init(unlocked),
+pub const Mutex = if (true) struct {
+    state: State,
+
+    pub const State = enum(usize) {
+        locked_once = 0b00,
+        unlocked = 0b01,
+        contended = 0b10,
+        /// contended
+        _,
+
+        pub fn isUnlocked(state: State) bool {
+            return @intFromEnum(state) & @intFromEnum(State.unlocked) == @intFromEnum(State.unlocked);
+        }
+    };
+
+    pub const init: Mutex = .{ .state = .unlocked };
+
+    pub fn tryLock(mutex: *Mutex) bool {
+        const prev_state: State = @enumFromInt(@atomicRmw(
+            usize,
+            @as(*usize, @ptrCast(&mutex.state)),
+            .And,
+            ~@intFromEnum(State.unlocked),
+            .acquire,
+        ));
+        return prev_state.isUnlocked();
+    }
+
+    pub fn lock(mutex: *Mutex, io: std.Io) error{Canceled}!void {
+        const prev_state: State = @enumFromInt(@atomicRmw(
+            usize,
+            @as(*usize, @ptrCast(&mutex.state)),
+            .And,
+            ~@intFromEnum(State.unlocked),
+            .acquire,
+        ));
+        if (prev_state.isUnlocked()) {
+            @branchHint(.likely);
+            return;
+        }
+        return io.vtable.mutexLock(io.userdata, prev_state, mutex);
+    }
+
+    pub fn unlock(mutex: *Mutex, io: std.Io) void {
+        const prev_state = @cmpxchgWeak(State, &mutex.state, .locked_once, .unlocked, .release, .acquire) orelse {
+            @branchHint(.likely);
+            return;
+        };
+        std.debug.assert(prev_state != .unlocked); // mutex not locked
+        return io.vtable.mutexUnlock(io.userdata, prev_state, mutex);
+    }
+} else struct {
+    state: std.atomic.Value(u32),
+
+    pub const State = void;
+
+    pub const init: Mutex = .{ .state = .init(unlocked) };
 
     pub const unlocked: u32 = 0b00;
     pub const locked: u32 = 0b01;
@@ -728,15 +783,15 @@ pub const Mutex = struct {
     }
 
     /// Avoids the vtable for uncontended locks.
-    pub fn lock(m: *Mutex, io: Io) void {
+    pub fn lock(m: *Mutex, io: Io) error{Canceled}!void {
         if (!m.tryLock()) {
             @branchHint(.unlikely);
-            io.vtable.mutexLock(io.userdata, m);
+            try io.vtable.mutexLock(io.userdata, {}, m);
         }
     }
 
     pub fn unlock(m: *Mutex, io: Io) void {
-        io.vtable.mutexUnlock(io.userdata, m);
+        io.vtable.mutexUnlock(io.userdata, {}, m);
     }
 };
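
Taken together, caller code changes only in that `lock` now needs the `Io` handle and can fail with `error.Canceled`; the uncontended path is a single atomic in `tryLock`/`lock` and never reaches the vtable. A usage sketch under those assumptions (hypothetical caller, not from the commit):

const std = @import("std");

var count_mutex: std.Io.Mutex = .init;
var count: usize = 0;

fn increment(io: std.Io) error{Canceled}!void {
    try count_mutex.lock(io); // fast path stays inline; contention goes through io.vtable.mutexLock
    defer count_mutex.unlock(io);
    count += 1;
}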