Commit 629a20459d

Jacob Young <jacobly0@users.noreply.github.com>
2025-03-27 17:03:35
EventLoop: rewrite context switching
1 parent fe6f1ef
Changed files (1)
lib
lib/std/Io/EventLoop.zig
@@ -1,4 +1,5 @@
 const std = @import("../std.zig");
+const builtin = @import("builtin");
 const assert = std.debug.assert;
 const Allocator = std.mem.Allocator;
 const Io = std.Io;
@@ -16,7 +17,7 @@ const max_result_len = 64;
 const min_stack_size = 4 * 1024 * 1024;
 
 const Fiber = struct {
-    regs: Regs,
+    context: Context,
     awaiter: ?*Fiber,
     queue_node: std.DoublyLinkedList(void).Node,
 
@@ -66,34 +67,28 @@ fn allocateFiber(el: *EventLoop, result_len: usize) error{OutOfMemory}!*Fiber {
 }
 
 fn yield(el: *EventLoop, optional_fiber: ?*Fiber, register_awaiter: ?*?*Fiber) void {
+    const ready_fiber: *Fiber = optional_fiber orelse if (ready_node: {
+        el.mutex.lock();
+        defer el.mutex.unlock();
+        break :ready_node el.queue.pop();
+    }) |ready_node|
+        @fieldParentPtr("queue_node", ready_node)
+    else if (register_awaiter) |_| // time to switch to an idle fiber?
+        @panic("no other fiber to switch to in order to be able to register this fiber as an awaiter")
+    else // nothing to do
+        return;
     const message: SwitchMessage = .{
-        .ready_fiber = optional_fiber orelse if (ready_node: {
-            el.mutex.lock();
-            defer el.mutex.unlock();
-            break :ready_node el.queue.pop();
-        }) |ready_node|
-            @fieldParentPtr("queue_node", ready_node)
-        else if (register_awaiter) |_|
-            @panic("no other fiber to switch to in order to be able to register this fiber as an awaiter") // time to switch to an idle fiber?
-        else
-            return, // nothing to do
+        .prev_context = &current_fiber.context,
+        .ready_context = &ready_fiber.context,
         .register_awaiter = register_awaiter,
     };
-    std.log.debug("switching from {*} to {*}", .{ current_fiber, message.ready_fiber });
-    SwitchMessage.handle(@ptrFromInt(contextSwitch(&current_fiber.regs, &message.ready_fiber.regs, @intFromPtr(&message))), el);
+    std.log.debug("switching from {*} to {*}", .{
+        @as(*Fiber, @fieldParentPtr("context", message.prev_context)),
+        @as(*Fiber, @fieldParentPtr("context", message.ready_context)),
+    });
+    contextSwitch(&message).handle(el);
 }
 
-const SwitchMessage = struct {
-    ready_fiber: *Fiber,
-    register_awaiter: ?*?*Fiber,
-
-    fn handle(message: *const SwitchMessage, el: *EventLoop) void {
-        const prev_fiber = current_fiber;
-        current_fiber = message.ready_fiber;
-        if (message.register_awaiter) |awaiter| if (@atomicRmw(?*Fiber, awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished) el.schedule(prev_fiber);
-    }
-};
-
 fn schedule(el: *EventLoop, fiber: *Fiber) void {
     el.mutex.lock();
     defer el.mutex.unlock();
@@ -109,47 +104,62 @@ fn recycle(el: *EventLoop, fiber: *Fiber) void {
     el.free.append(&fiber.queue_node);
 }
 
-const Regs = extern struct {
+const SwitchMessage = extern struct {
+    prev_context: *Context,
+    ready_context: *Context,
+    register_awaiter: ?*?*Fiber,
+
+    fn handle(message: *const SwitchMessage, el: *EventLoop) void {
+        const prev_fiber: *Fiber = @fieldParentPtr("context", message.prev_context);
+        current_fiber = @fieldParentPtr("context", message.ready_context);
+        if (message.register_awaiter) |awaiter| if (@atomicRmw(?*Fiber, awaiter, .Xchg, prev_fiber, .acq_rel) == Fiber.finished) el.schedule(prev_fiber);
+    }
+};
+
+const Context = extern struct {
     rsp: usize,
-    r15: usize,
-    r14: usize,
-    r13: usize,
-    r12: usize,
-    rbx: usize,
     rbp: usize,
+    rip: usize,
 };
 
-const contextSwitch: *const fn (old: *Regs, new: *Regs, message: usize) callconv(.c) usize = @ptrCast(&contextSwitch_naked);
-
-noinline fn contextSwitch_naked() callconv(.naked) void {
-    asm volatile (
-        \\movq %%rsp, 0x00(%%rdi)
-        \\movq %%r15, 0x08(%%rdi)
-        \\movq %%r14, 0x10(%%rdi)
-        \\movq %%r13, 0x18(%%rdi)
-        \\movq %%r12, 0x20(%%rdi)
-        \\movq %%rbx, 0x28(%%rdi)
-        \\movq %%rbp, 0x30(%%rdi)
-        \\
-        \\movq 0x00(%%rsi), %%rsp
-        \\movq 0x08(%%rsi), %%r15
-        \\movq 0x10(%%rsi), %%r14
-        \\movq 0x18(%%rsi), %%r13
-        \\movq 0x20(%%rsi), %%r12
-        \\movq 0x28(%%rsi), %%rbx
-        \\movq 0x30(%%rsi), %%rbp
-        \\
-        \\movq %%rdx, %%rax
-        \\ret
-    );
+inline fn contextSwitch(message: *const SwitchMessage) *const SwitchMessage {
+    return switch (builtin.cpu.arch) {
+        .x86_64 => asm volatile (
+            \\ movq 0(%%rsi), %%rax
+            \\ movq 8(%%rsi), %%rcx
+            \\ leaq 0f(%%rip), %%rdx
+            \\ movq %%rsp, 0(%%rax)
+            \\ movq %%rbp, 8(%%rax)
+            \\ movq %%rdx, 16(%%rax)
+            \\ movq 0(%%rcx), %%rsp
+            \\ movq 8(%%rcx), %%rbp
+            \\ jmpq *16(%%rcx)
+            \\0:
+            : [received_message] "={rsi}" (-> *const SwitchMessage),
+            : [message_to_send] "{rsi}" (message),
+            : "rax", "rcx", "rdx", "rbx", "rdi", //
+            "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", //
+            "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7", //
+            "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", //
+            "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", //
+            "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", //
+            "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", //
+            "fpsr", "fpcr", "mxcsr", "rflags", "dirflag", "memory"
+        ),
+        else => |arch| @compileError("unimplemented architecture: " ++ @tagName(arch)),
+    };
 }
 
-fn popRet() callconv(.naked) void {
-    asm volatile (
-        \\pop %%rdi
-        \\movq %%rax, %%rsi
-        \\ret
-    );
+fn fiberEntry() callconv(.naked) void {
+    switch (builtin.cpu.arch) {
+        .x86_64 => asm volatile (
+            \\ leaq 8(%%rsp), %%rdi
+            \\ jmp %[AsyncClosure_call:P]
+            :
+            : [AsyncClosure_call] "X" (&AsyncClosure.call),
+        ),
+        else => |arch| @compileError("unimplemented architecture: " ++ @tagName(arch)),
+    }
 }
 
 pub fn @"async"(
@@ -179,21 +189,10 @@ pub fn @"async"(
         .start = start,
     };
     const stack_end: [*]align(16) usize = @alignCast(@ptrCast(closure));
-    const stack_top = (stack_end - 4)[0..4];
-    stack_top.* = .{
-        @intFromPtr(&popRet),
-        @intFromPtr(closure),
-        @intFromPtr(&AsyncClosure.call),
-        0,
-    };
-    fiber.regs = .{
-        .rsp = @intFromPtr(stack_top),
-        .r15 = 0,
-        .r14 = 0,
-        .r13 = 0,
-        .r12 = 0,
-        .rbx = 0,
+    fiber.context = .{
+        .rsp = @intFromPtr(stack_end - 1),
         .rbp = 0,
+        .rip = @intFromPtr(&fiberEntry),
     };
 
     event_loop.schedule(fiber);