Commit 7aa4062f5c

Andrew Kelley <andrew@ziglang.org>
2025-04-04 02:54:37
introduce Io.select and implement it in thread pool
1 parent c4fcf85
Changed files (2)
lib
std
lib/std/Thread/Pool.zig
@@ -335,6 +335,7 @@ pub fn io(pool: *Pool) Io {
             .go = go,
             .cancel = cancel,
             .cancelRequested = cancelRequested,
+            .select = select,
 
             .mutexLock = mutexLock,
             .mutexUnlock = mutexUnlock,
@@ -358,10 +359,13 @@ const AsyncClosure = struct {
     func: *const fn (context: *anyopaque, result: *anyopaque) void,
     runnable: Runnable = .{ .runFn = runFn },
     reset_event: std.Thread.ResetEvent,
+    select_condition: ?*std.Thread.ResetEvent,
     cancel_tid: std.Thread.Id,
     context_offset: usize,
     result_offset: usize,
 
+    const done_reset_event: *std.Thread.ResetEvent = @ptrFromInt(std.mem.alignBackward(usize, std.math.maxInt(usize), @alignOf(std.Thread.ResetEvent)));
+
     const canceling_tid: std.Thread.Id = switch (@typeInfo(std.Thread.Id)) {
         .int => |int_info| switch (int_info.signedness) {
             .signed => -1,
@@ -396,6 +400,17 @@ const AsyncClosure = struct {
             .acq_rel,
             .acquire,
         )) |cancel_tid| assert(cancel_tid == canceling_tid);
+
+        if (@atomicRmw(
+            ?*std.Thread.ResetEvent,
+            &closure.select_condition,
+            .Xchg,
+            done_reset_event,
+            .release,
+        )) |select_reset| {
+            assert(select_reset != done_reset_event);
+            select_reset.set();
+        }
         closure.reset_event.set();
     }
 
@@ -455,6 +470,7 @@ fn @"async"(
         .result_offset = result_offset,
         .reset_event = .{},
         .cancel_tid = 0,
+        .select_condition = null,
     };
     @memcpy(closure.contextPointer()[0..context.len], context);
     pool.run_queue.prepend(&closure.runnable.node);
@@ -720,47 +736,54 @@ fn conditionWake(userdata: ?*anyopaque, cond: *Io.Condition, wake: Io.Condition.
 
 fn createFile(
     userdata: ?*anyopaque,
-    dir: std.fs.Dir,
+    dir: Io.Dir,
     sub_path: []const u8,
-    flags: std.fs.File.CreateFlags,
-) Io.FileOpenError!std.fs.File {
+    flags: Io.File.CreateFlags,
+) Io.File.OpenError!Io.File {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     try pool.checkCancel();
-    return dir.createFile(sub_path, flags);
+    const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
+    const fs_file = try fs_dir.createFile(sub_path, flags);
+    return .{ .handle = fs_file.handle };
 }
 
 fn openFile(
     userdata: ?*anyopaque,
-    dir: std.fs.Dir,
+    dir: Io.Dir,
     sub_path: []const u8,
-    flags: std.fs.File.OpenFlags,
-) Io.FileOpenError!std.fs.File {
+    flags: Io.File.OpenFlags,
+) Io.File.OpenError!Io.File {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     try pool.checkCancel();
-    return dir.openFile(sub_path, flags);
+    const fs_dir: std.fs.Dir = .{ .fd = dir.handle };
+    const fs_file = try fs_dir.openFile(sub_path, flags);
+    return .{ .handle = fs_file.handle };
 }
 
-fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void {
+fn closeFile(userdata: ?*anyopaque, file: Io.File) void {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     _ = pool;
-    return file.close();
+    const fs_file: std.fs.File = .{ .handle = file.handle };
+    return fs_file.close();
 }
 
-fn pread(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8, offset: std.posix.off_t) Io.FilePReadError!usize {
+fn pread(userdata: ?*anyopaque, file: Io.File, buffer: []u8, offset: std.posix.off_t) Io.File.PReadError!usize {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     try pool.checkCancel();
+    const fs_file: std.fs.File = .{ .handle = file.handle };
     return switch (offset) {
-        -1 => file.read(buffer),
-        else => file.pread(buffer, @bitCast(offset)),
+        -1 => fs_file.read(buffer),
+        else => fs_file.pread(buffer, @bitCast(offset)),
     };
 }
 
-fn pwrite(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8, offset: std.posix.off_t) Io.FilePWriteError!usize {
+fn pwrite(userdata: ?*anyopaque, file: Io.File, buffer: []const u8, offset: std.posix.off_t) Io.File.PWriteError!usize {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     try pool.checkCancel();
+    const fs_file: std.fs.File = .{ .handle = file.handle };
     return switch (offset) {
-        -1 => file.write(buffer),
-        else => file.pwrite(buffer, @bitCast(offset)),
+        -1 => fs_file.write(buffer),
+        else => fs_file.pwrite(buffer, @bitCast(offset)),
     };
 }
 
@@ -774,7 +797,7 @@ fn now(userdata: ?*anyopaque, clockid: std.posix.clockid_t) Io.ClockGetTimeError
 fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadline) Io.SleepError!void {
     const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
     const deadline_nanoseconds: i96 = switch (deadline) {
-        .nanoseconds => |nanoseconds| nanoseconds,
+        .duration => |duration| duration.nanoseconds,
         .timestamp => |timestamp| @intFromEnum(timestamp),
     };
     var timespec: std.posix.timespec = .{
@@ -784,7 +807,7 @@ fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadl
     while (true) {
         try pool.checkCancel();
         switch (std.os.linux.E.init(std.os.linux.clock_nanosleep(clockid, .{ .ABSTIME = switch (deadline) {
-            .nanoseconds => false,
+            .duration => false,
             .timestamp => true,
         } }, &timespec, &timespec))) {
             .SUCCESS => return,
@@ -795,3 +818,35 @@ fn sleep(userdata: ?*anyopaque, clockid: std.posix.clockid_t, deadline: Io.Deadl
         }
     }
 }
+
+fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
+    const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+    _ = pool;
+
+    var reset_event: std.Thread.ResetEvent = .{};
+
+    for (futures, 0..) |future, i| {
+        const closure: *AsyncClosure = @ptrCast(@alignCast(future));
+        if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, &reset_event, .seq_cst) == AsyncClosure.done_reset_event) {
+            for (futures[0..i]) |cleanup_future| {
+                const cleanup_closure: *AsyncClosure = @ptrCast(@alignCast(cleanup_future));
+                if (@atomicRmw(?*std.Thread.ResetEvent, &cleanup_closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
+                    cleanup_closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+                }
+            }
+            return i;
+        }
+    }
+
+    reset_event.wait();
+
+    var result: ?usize = null;
+    for (futures, 0..) |future, i| {
+        const closure: *AsyncClosure = @ptrCast(@alignCast(future));
+        if (@atomicRmw(?*std.Thread.ResetEvent, &closure.select_condition, .Xchg, null, .seq_cst) == AsyncClosure.done_reset_event) {
+            closure.reset_event.wait(); // Ensure no reference to our stack-allocated reset_event.
+            if (result == null) result = i; // In case multiple are ready, return first.
+        }
+    }
+    return result.?;
+}
lib/std/Io.zig
@@ -626,17 +626,21 @@ pub const VTable = struct {
     /// Thread-safe.
     cancelRequested: *const fn (?*anyopaque) bool,
 
+    /// Blocks until one of the futures from the list has a result ready, such
+    /// that awaiting it will not block. Returns that index.
+    select: *const fn (?*anyopaque, futures: []const *AnyFuture) usize,
+
     mutexLock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) Cancelable!void,
     mutexUnlock: *const fn (?*anyopaque, prev_state: Mutex.State, mutex: *Mutex) void,
 
     conditionWait: *const fn (?*anyopaque, cond: *Condition, mutex: *Mutex) Cancelable!void,
     conditionWake: *const fn (?*anyopaque, cond: *Condition, wake: Condition.Wake) void,
 
-    createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File,
-    openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File,
-    closeFile: *const fn (?*anyopaque, fs.File) void,
-    pread: *const fn (?*anyopaque, file: fs.File, buffer: []u8, offset: std.posix.off_t) FilePReadError!usize,
-    pwrite: *const fn (?*anyopaque, file: fs.File, buffer: []const u8, offset: std.posix.off_t) FilePWriteError!usize,
+    createFile: *const fn (?*anyopaque, dir: Dir, sub_path: []const u8, flags: File.CreateFlags) File.OpenError!File,
+    openFile: *const fn (?*anyopaque, dir: Dir, sub_path: []const u8, flags: File.OpenFlags) File.OpenError!File,
+    closeFile: *const fn (?*anyopaque, File) void,
+    pread: *const fn (?*anyopaque, file: File, buffer: []u8, offset: std.posix.off_t) File.PReadError!usize,
+    pwrite: *const fn (?*anyopaque, file: File, buffer: []const u8, offset: std.posix.off_t) File.PWriteError!usize,
 
     now: *const fn (?*anyopaque, clockid: std.posix.clockid_t) ClockGetTimeError!Timestamp,
     sleep: *const fn (?*anyopaque, clockid: std.posix.clockid_t, deadline: Deadline) SleepError!void,
@@ -647,28 +651,118 @@ pub const Cancelable = error{
     Canceled,
 };
 
-pub const OpenFlags = fs.File.OpenFlags;
-pub const CreateFlags = fs.File.CreateFlags;
+pub const Dir = struct {
+    handle: Handle,
+
+    pub fn cwd() Dir {
+        return .{ .handle = std.fs.cwd().fd };
+    }
+
+    pub const Handle = std.posix.fd_t;
+
+    pub fn openFile(dir: Dir, io: Io, sub_path: []const u8, flags: File.OpenFlags) File.OpenError!File {
+        return io.vtable.openFile(io.userdata, dir, sub_path, flags);
+    }
+
+    pub fn createFile(dir: Dir, io: Io, sub_path: []const u8, flags: File.CreateFlags) File.OpenError!File {
+        return io.vtable.createFile(io.userdata, dir, sub_path, flags);
+    }
+
+    pub const WriteFileOptions = struct {
+        /// On Windows, `sub_path` should be encoded as [WTF-8](https://simonsapin.github.io/wtf-8/).
+        /// On WASI, `sub_path` should be encoded as valid UTF-8.
+        /// On other platforms, `sub_path` is an opaque sequence of bytes with no particular encoding.
+        sub_path: []const u8,
+        data: []const u8,
+        flags: File.CreateFlags = .{},
+    };
+
+    pub const WriteFileError = File.WriteError || File.OpenError || Cancelable;
+
+    /// Writes content to the file system, using the file creation flags provided.
+    pub fn writeFile(dir: Dir, io: Io, options: WriteFileOptions) WriteFileError!void {
+        var file = try dir.createFile(io, options.sub_path, options.flags);
+        defer file.close(io);
+        try file.writeAll(io, options.data);
+    }
+};
+
+pub const File = struct {
+    handle: Handle,
+
+    pub const Handle = std.posix.fd_t;
+
+    pub const OpenFlags = fs.File.OpenFlags;
+    pub const CreateFlags = fs.File.CreateFlags;
+
+    pub const OpenError = fs.File.OpenError || Cancelable;
+
+    pub fn close(file: File, io: Io) void {
+        return io.vtable.closeFile(io.userdata, file);
+    }
+
+    pub const ReadError = fs.File.ReadError || Cancelable;
+
+    pub fn read(file: File, io: Io, buffer: []u8) ReadError!usize {
+        return @errorCast(file.pread(io, buffer, -1));
+    }
+
+    pub const PReadError = fs.File.PReadError || Cancelable;
+
+    pub fn pread(file: File, io: Io, buffer: []u8, offset: std.posix.off_t) PReadError!usize {
+        return io.vtable.pread(io.userdata, file, buffer, offset);
+    }
+
+    pub const WriteError = fs.File.WriteError || Cancelable;
+
+    pub fn write(file: File, io: Io, buffer: []const u8) WriteError!usize {
+        return @errorCast(file.pwrite(io, buffer, -1));
+    }
+
+    pub const PWriteError = fs.File.PWriteError || Cancelable;
+
+    pub fn pwrite(file: File, io: Io, buffer: []const u8, offset: std.posix.off_t) PWriteError!usize {
+        return io.vtable.pwrite(io.userdata, file, buffer, offset);
+    }
+
+    pub fn writeAll(file: File, io: Io, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try file.write(io, bytes[index..]);
+        }
+    }
 
-pub const FileOpenError = fs.File.OpenError || Cancelable;
-pub const FileReadError = fs.File.ReadError || Cancelable;
-pub const FilePReadError = fs.File.PReadError || Cancelable;
-pub const FileWriteError = fs.File.WriteError || Cancelable;
-pub const FilePWriteError = fs.File.PWriteError || Cancelable;
+    pub fn readAll(file: File, io: Io, buffer: []u8) ReadError!usize {
+        var index: usize = 0;
+        while (index != buffer.len) {
+            const amt = try file.read(io, buffer[index..]);
+            if (amt == 0) break;
+            index += amt;
+        }
+        return index;
+    }
+};
 
 pub const Timestamp = enum(i96) {
     _,
 
-    pub fn durationTo(from: Timestamp, to: Timestamp) i96 {
-        return @intFromEnum(to) - @intFromEnum(from);
+    pub fn durationTo(from: Timestamp, to: Timestamp) Duration {
+        return .{ .nanoseconds = @intFromEnum(to) - @intFromEnum(from) };
     }
 
-    pub fn addDuration(from: Timestamp, duration: i96) Timestamp {
-        return @enumFromInt(@intFromEnum(from) + duration);
+    pub fn addDuration(from: Timestamp, duration: Duration) Timestamp {
+        return @enumFromInt(@intFromEnum(from) + duration.nanoseconds);
     }
 };
-pub const Deadline = union(enum) {
+pub const Duration = struct {
     nanoseconds: i96,
+
+    pub fn ms(x: u64) Duration {
+        return .{ .nanoseconds = @as(i96, x) * std.time.ns_per_ms };
+    }
+};
+pub const Deadline = union(enum) {
+    duration: Duration,
     timestamp: Timestamp,
 };
 pub const ClockGetTimeError = std.posix.ClockGetTimeError || Cancelable;
@@ -1055,7 +1149,7 @@ pub fn Queue(Elem: type) type {
 
 /// Calls `function` with `args`, such that the return value of the function is
 /// not guaranteed to be available until `await` is called.
-pub fn async(io: Io, function: anytype, args: anytype) Future(@typeInfo(@TypeOf(function)).@"fn".return_type.?) {
+pub fn async(io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) Future(@typeInfo(@TypeOf(function)).@"fn".return_type.?) {
     const Result = @typeInfo(@TypeOf(function)).@"fn".return_type.?;
     const Args = @TypeOf(args);
     const TypeErased = struct {
@@ -1079,7 +1173,7 @@ pub fn async(io: Io, function: anytype, args: anytype) Future(@typeInfo(@TypeOf(
 
 /// Calls `function` with `args` asynchronously. The resource cleans itself up
 /// when the function returns. Does not support await, cancel, or a return value.
-pub fn go(io: Io, function: anytype, args: anytype) void {
+pub fn go(io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void {
     const Args = @TypeOf(args);
     const TypeErased = struct {
         fn start(context: *const anyopaque) void {
@@ -1095,55 +1189,56 @@ pub fn go(io: Io, function: anytype, args: anytype) void {
     );
 }
 
-pub fn openFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File {
-    return io.vtable.openFile(io.userdata, dir, sub_path, flags);
-}
-
-pub fn createFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File {
-    return io.vtable.createFile(io.userdata, dir, sub_path, flags);
-}
-
-pub fn closeFile(io: Io, file: fs.File) void {
-    return io.vtable.closeFile(io.userdata, file);
-}
-
-pub fn read(io: Io, file: fs.File, buffer: []u8) FileReadError!usize {
-    return @errorCast(io.pread(file, buffer, -1));
-}
-
-pub fn pread(io: Io, file: fs.File, buffer: []u8, offset: std.posix.off_t) FilePReadError!usize {
-    return io.vtable.pread(io.userdata, file, buffer, offset);
+pub fn now(io: Io, clockid: std.posix.clockid_t) ClockGetTimeError!Timestamp {
+    return io.vtable.now(io.userdata, clockid);
 }
 
-pub fn write(io: Io, file: fs.File, buffer: []const u8) FileWriteError!usize {
-    return @errorCast(io.pwrite(file, buffer, -1));
+pub fn sleep(io: Io, clockid: std.posix.clockid_t, deadline: Deadline) SleepError!void {
+    return io.vtable.sleep(io.userdata, clockid, deadline);
 }
 
-pub fn pwrite(io: Io, file: fs.File, buffer: []const u8, offset: std.posix.off_t) FilePWriteError!usize {
-    return io.vtable.pwrite(io.userdata, file, buffer, offset);
+pub fn sleepDuration(io: Io, duration: Duration) SleepError!void {
+    return io.vtable.sleep(io.userdata, .MONOTONIC, .{ .duration = duration });
 }
 
-pub fn writeAll(io: Io, file: fs.File, bytes: []const u8) FileWriteError!void {
-    var index: usize = 0;
-    while (index < bytes.len) {
-        index += try io.write(file, bytes[index..]);
+/// Given a struct with each field a `*Future`, returns a union with the same
+/// fields, each field type the future's result.
+pub fn SelectUnion(S: type) type {
+    const struct_fields = @typeInfo(S).@"struct".fields;
+    var fields: [struct_fields.len]std.builtin.Type.UnionField = undefined;
+    for (&fields, struct_fields) |*union_field, struct_field| {
+        const F = @typeInfo(struct_field.type).pointer.child;
+        const Result = @TypeOf(@as(F, undefined).result);
+        union_field.* = .{
+            .name = struct_field.name,
+            .type = Result,
+            .alignment = struct_field.alignment,
+        };
     }
+    return @Type(.{ .@"union" = .{
+        .layout = .auto,
+        .tag_type = std.meta.FieldEnum(S),
+        .fields = &fields,
+        .decls = &.{},
+    } });
 }
 
-pub fn readAll(io: Io, file: fs.File, buffer: []u8) FileReadError!usize {
-    var index: usize = 0;
-    while (index != buffer.len) {
-        const amt = try io.read(file, buffer[index..]);
-        if (amt == 0) break;
-        index += amt;
+/// `s` is a struct with every field a `*Future(T)`, where `T` can be any type,
+/// and can be different for each field.
+pub fn select(io: Io, s: anytype) SelectUnion(@TypeOf(s)) {
+    const U = SelectUnion(@TypeOf(s));
+    const S = @TypeOf(s);
+    const fields = @typeInfo(S).@"struct".fields;
+    var futures: [fields.len]*AnyFuture = undefined;
+    inline for (fields, &futures) |field, *any_future| {
+        const future = @field(s, field.name);
+        any_future.* = future.any_future orelse return @unionInit(U, field.name, future.result);
+    }
+    switch (io.vtable.select(io.userdata, &futures)) {
+        inline 0...(fields.len - 1) => |selected_index| {
+            const field_name = fields[selected_index].name;
+            return @unionInit(U, field_name, @field(s, field_name).await(io));
+        },
+        else => unreachable,
     }
-    return index;
-}
-
-pub fn now(io: Io, clockid: std.posix.clockid_t) ClockGetTimeError!Timestamp {
-    return io.vtable.now(io.userdata, clockid);
-}
-
-pub fn sleep(io: Io, clockid: std.posix.clockid_t, deadline: Deadline) SleepError!void {
-    return io.vtable.sleep(io.userdata, clockid, deadline);
 }