Commit d3c4158a10

Andrew Kelley <andrew@ziglang.org>
2025-10-15 09:36:02
std.Io: implement Select
and finish implementation of HostName.connect
1 parent 35ce907
Changed files (5)
lib/std/Io/net/HostName.zig
@@ -88,7 +88,7 @@ pub const LookupResult = union(enum) {
 /// Adds any number of `IpAddress` into resolved, exactly one canonical_name,
 /// and then always finishes by adding one `LookupResult.end` entry.
 ///
-/// Guaranteed not to block if provided queue has capacity at least 8.
+/// Guaranteed not to block if provided queue has capacity at least 16.
 pub fn lookup(
     host_name: HostName,
     io: Io,
@@ -216,11 +216,13 @@ pub fn connect(
     } });
     defer lookup_task.cancel(io);
 
-    var select: Io.Select(union(enum) { ip_connect: IpAddress.ConnectError!Stream }) = .init;
-    defer select.cancel(io);
+    const Result = union(enum) { connect_result: IpAddress.ConnectError!Stream };
+    var finished_task_buffer: [results_buffer.len]Result = undefined;
+    var select: Io.Select(Result) = .init(io, &finished_task_buffer);
+    defer select.cancel();
 
     while (results.getOne(io)) |result| switch (result) {
-        .address => |address| select.async(io, .ip_connect, IpAddress.connect, .{ address, io, options }),
+        .address => |address| select.async(.connect_result, IpAddress.connect, .{ address, io, options }),
         .canonical_name => continue,
         .end => |lookup_result| {
             try lookup_result;
@@ -230,8 +232,8 @@ pub fn connect(
 
     var aggregate_error: ConnectError = error.UnknownHostName;
 
-    while (select.remaining != 0) switch (select.wait(io)) {
-        .ip_connect => |ip_connect| if (ip_connect) |stream| return stream else |err| switch (err) {
+    while (select.outstanding != 0) switch (try select.wait()) {
+        .connect_result => |connect_result| if (connect_result) |stream| return stream else |err| switch (err) {
             error.SystemResources => |e| return e,
             error.OptionUnsupported => |e| return e,
             error.ProcessFdQuotaExceeded => |e| return e,
lib/std/Io/net/test.zig
@@ -1,5 +1,7 @@
-const std = @import("std");
 const builtin = @import("builtin");
+
+const std = @import("std");
+const Io = std.Io;
 const net = std.Io.net;
 const mem = std.mem;
 const testing = std.testing;
@@ -126,33 +128,56 @@ test "resolve DNS" {
         const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80);
         const localhost_v6 = try net.IpAddress.parse("::2", 80);
 
-        var addresses_buffer: [8]net.IpAddress = undefined;
-        var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
-        const result = try net.HostName.lookup(try .init("localhost"), io, .{
+        var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
+        var results_buffer: [32]net.HostName.LookupResult = undefined;
+        var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
+
+        net.HostName.lookup(try .init("localhost"), io, &results, .{
             .port = 80,
-            .addresses_buffer = &addresses_buffer,
-            .canonical_name_buffer = &canon_name_buffer,
+            .canonical_name_buffer = &canonical_name_buffer,
         });
-        for (addresses_buffer[0..result.addresses_len]) |addr| {
-            if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break;
-        } else @panic("unexpected address for localhost");
+
+        var addresses_found: usize = 0;
+
+        while (results.getOne(io)) |result| switch (result) {
+            .address => |address| {
+                if (address.eql(&localhost_v4) or address.eql(&localhost_v6))
+                    addresses_found += 1;
+            },
+            .canonical_name => |canonical_name| try testing.expectEqualStrings("localhost", canonical_name.bytes),
+            .end => |end| {
+                try end;
+                break;
+            },
+        } else |err| return err;
+
+        try testing.expect(addresses_found != 0);
     }
 
     {
         // The tests are required to work even when there is no Internet connection,
         // so some of these errors we must accept and skip the test.
-        var addresses_buffer: [8]net.IpAddress = undefined;
-        var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
-        const result = net.HostName.lookup(try .init("example.com"), io, .{
+        var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
+        var results_buffer: [16]net.HostName.LookupResult = undefined;
+        var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
+
+        net.HostName.lookup(try .init("example.com"), io, &results, .{
             .port = 80,
-            .addresses_buffer = &addresses_buffer,
-            .canonical_name_buffer = &canon_name_buffer,
-        }) catch |err| switch (err) {
-            error.UnknownHostName => return error.SkipZigTest,
-            error.NameServerFailure => return error.SkipZigTest,
-            else => return err,
-        };
-        _ = result;
+            .canonical_name_buffer = &canonical_name_buffer,
+        });
+
+        while (results.getOne(io)) |result| switch (result) {
+            .address => {},
+            .canonical_name => {},
+            .end => |end| {
+                end catch |err| switch (err) {
+                    error.UnknownHostName => return error.SkipZigTest,
+                    error.NameServerFailure => return error.SkipZigTest,
+                    else => return err,
+                };
+                break;
+            },
+        } else |err| return err;
     }
 }
 
lib/std/Io/net.zig
@@ -315,8 +315,8 @@ pub const IpAddress = union(enum) {
     };
 
     /// Initiates a connection-oriented network stream.
-    pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
-        return io.vtable.netConnectIp(io.userdata, address, options);
+    pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
+        return io.vtable.netConnectIp(io.userdata, &address, options);
     }
 };
 
lib/std/Io/Threaded.zig
@@ -458,7 +458,7 @@ const GroupClosure = struct {
     group: *Io.Group,
     /// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
     node: std.SinglyLinkedList.Node,
-    func: *const fn (context: *anyopaque) void,
+    func: *const fn (*Io.Group, context: *anyopaque) void,
     context_alignment: std.mem.Alignment,
     context_len: usize,
 
@@ -476,7 +476,7 @@ const GroupClosure = struct {
             return;
         }
         current_closure = closure;
-        gc.func(gc.contextPointer());
+        gc.func(group, gc.contextPointer());
         current_closure = null;
 
         // In case a cancel happens after successful task completion, prevents
@@ -512,7 +512,7 @@ fn groupAsync(
     group: *Io.Group,
     context: []const u8,
     context_alignment: std.mem.Alignment,
-    start: *const fn (context: *const anyopaque) void,
+    start: *const fn (*Io.Group, context: *const anyopaque) void,
 ) void {
     if (builtin.single_threaded) return start(context.ptr);
     const pool: *Pool = @ptrCast(@alignCast(userdata));
@@ -520,7 +520,7 @@ fn groupAsync(
     const gpa = pool.allocator;
     const n = GroupClosure.contextEnd(context_alignment, context.len);
     const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
-        return start(context.ptr);
+        return start(group, context.ptr);
     }));
     gc.* = .{
         .closure = .{
@@ -548,7 +548,7 @@ fn groupAsync(
     pool.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
         pool.mutex.unlock();
         gc.free(gpa);
-        return start(context.ptr);
+        return start(group, context.ptr);
     };
 
     pool.run_queue.prepend(&gc.closure.node);
@@ -558,7 +558,7 @@ fn groupAsync(
             assert(pool.run_queue.popFirst() == &gc.closure.node);
             pool.mutex.unlock();
             gc.free(gpa);
-            return start(context.ptr);
+            return start(group, context.ptr);
         };
         pool.threads.appendAssumeCapacity(thread);
     }
@@ -2662,6 +2662,7 @@ fn netLookupFallible(
                     .{ .address = addr },
                     .{ .canonical_name = copyCanon(options.canonical_name_buffer, name) },
                 });
+                return;
             } else |_| {}
         }
 
lib/std/Io.zig
@@ -639,7 +639,7 @@ pub const VTable = struct {
         /// Copied and then passed to `start`.
         context: []const u8,
         context_alignment: std.mem.Alignment,
-        start: *const fn (context: *const anyopaque) void,
+        start: *const fn (*Group, context: *const anyopaque) void,
     ) void,
     groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
     groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
@@ -1005,7 +1005,8 @@ pub const Group = struct {
     pub fn async(g: *Group, io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void {
         const Args = @TypeOf(args);
         const TypeErased = struct {
-            fn start(context: *const anyopaque) void {
+            fn start(group: *Group, context: *const anyopaque) void {
+                _ = group;
                 const args_casted: *const Args = @ptrCast(@alignCast(context));
                 @call(.auto, function, args_casted.*);
             }
@@ -1033,6 +1034,85 @@ pub const Group = struct {
     }
 };
 
+pub fn Select(comptime U: type) type {
+    return struct {
+        io: Io,
+        group: Group,
+        queue: Queue(U),
+        outstanding: usize,
+
+        const S = @This();
+
+        pub const Union = U;
+
+        pub const Field = std.meta.FieldEnum(U);
+
+        pub fn init(io: Io, buffer: []U) S {
+            return .{
+                .io = io,
+                .queue = .init(buffer),
+                .group = .init,
+                .outstanding = 0,
+            };
+        }
+
+        /// Calls `function` with `args` asynchronously. The resource spawned is
+        /// owned by the select.
+        ///
+        /// `function` must have return type matching the `field` field of `Union`.
+        ///
+        /// `function` *may* be called immediately, before `async` returns.
+        ///
+        /// After this is called, `wait` or `cancel` must be called before the
+        /// select is deinitialized.
+        ///
+        /// Threadsafe.
+        ///
+        /// Related:
+        /// * `Io.async`
+        /// * `Group.async`
+        pub fn async(
+            s: *S,
+            comptime field: Field,
+            function: anytype,
+            args: std.meta.ArgsTuple(@TypeOf(function)),
+        ) void {
+            const Args = @TypeOf(args);
+            const TypeErased = struct {
+                fn start(group: *Group, context: *const anyopaque) void {
+                    const args_casted: *const Args = @ptrCast(@alignCast(context));
+                    const unerased_select: *S = @fieldParentPtr("group", group);
+                    const elem = @unionInit(U, @tagName(field), @call(.auto, function, args_casted.*));
+                    unerased_select.queue.putOneUncancelable(unerased_select.io, elem);
+                }
+            };
+            _ = @atomicRmw(usize, &s.outstanding, .Add, 1, .monotonic);
+            s.io.vtable.groupAsync(s.io.userdata, &s.group, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
+        }
+
+        /// Blocks until another task of the select finishes.
+        ///
+        /// Asserts there is at least one more `outstanding` task.
+        ///
+        /// Not threadsafe.
+        pub fn wait(s: *S) Io.Cancelable!U {
+            s.outstanding -= 1;
+            return s.queue.getOne(s.io);
+        }
+
+        /// Equivalent to `wait` but requests cancellation on all remaining
+        /// tasks owned by the select.
+        ///
+        /// It is illegal to call `wait` after this.
+        ///
+        /// Idempotent. Not threadsafe.
+        pub fn cancel(s: *S) void {
+            s.outstanding = 0;
+            s.group.cancel(s.io);
+        }
+    };
+}
+
 pub const Mutex = struct {
     state: State,