Commit d3c4158a10
Changed files (5)
lib
lib/std/Io/net/HostName.zig
@@ -88,7 +88,7 @@ pub const LookupResult = union(enum) {
/// Adds any number of `IpAddress` into resolved, exactly one canonical_name,
/// and then always finishes by adding one `LookupResult.end` entry.
///
-/// Guaranteed not to block if provided queue has capacity at least 8.
+/// Guaranteed not to block if provided queue has capacity at least 16.
pub fn lookup(
host_name: HostName,
io: Io,
@@ -216,11 +216,13 @@ pub fn connect(
} });
defer lookup_task.cancel(io);
- var select: Io.Select(union(enum) { ip_connect: IpAddress.ConnectError!Stream }) = .init;
- defer select.cancel(io);
+ const Result = union(enum) { connect_result: IpAddress.ConnectError!Stream };
+ var finished_task_buffer: [results_buffer.len]Result = undefined;
+ var select: Io.Select(Result) = .init(io, &finished_task_buffer);
+ defer select.cancel();
while (results.getOne(io)) |result| switch (result) {
- .address => |address| select.async(io, .ip_connect, IpAddress.connect, .{ address, io, options }),
+ .address => |address| select.async(.connect_result, IpAddress.connect, .{ address, io, options }),
.canonical_name => continue,
.end => |lookup_result| {
try lookup_result;
@@ -230,8 +232,8 @@ pub fn connect(
var aggregate_error: ConnectError = error.UnknownHostName;
- while (select.remaining != 0) switch (select.wait(io)) {
- .ip_connect => |ip_connect| if (ip_connect) |stream| return stream else |err| switch (err) {
+ while (select.outstanding != 0) switch (try select.wait()) {
+ .connect_result => |connect_result| if (connect_result) |stream| return stream else |err| switch (err) {
error.SystemResources => |e| return e,
error.OptionUnsupported => |e| return e,
error.ProcessFdQuotaExceeded => |e| return e,
lib/std/Io/net/test.zig
@@ -1,5 +1,7 @@
-const std = @import("std");
const builtin = @import("builtin");
+
+const std = @import("std");
+const Io = std.Io;
const net = std.Io.net;
const mem = std.mem;
const testing = std.testing;
@@ -126,33 +128,56 @@ test "resolve DNS" {
const localhost_v4 = try net.IpAddress.parse("127.0.0.1", 80);
const localhost_v6 = try net.IpAddress.parse("::2", 80);
- var addresses_buffer: [8]net.IpAddress = undefined;
- var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
- const result = try net.HostName.lookup(try .init("localhost"), io, .{
+ var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
+ var results_buffer: [32]net.HostName.LookupResult = undefined;
+ var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
+
+ net.HostName.lookup(try .init("localhost"), io, &results, .{
.port = 80,
- .addresses_buffer = &addresses_buffer,
- .canonical_name_buffer = &canon_name_buffer,
+ .canonical_name_buffer = &canonical_name_buffer,
});
- for (addresses_buffer[0..result.addresses_len]) |addr| {
- if (addr.eql(&localhost_v4) or addr.eql(&localhost_v6)) break;
- } else @panic("unexpected address for localhost");
+
+ var addresses_found: usize = 0;
+
+ while (results.getOne(io)) |result| switch (result) {
+ .address => |address| {
+ if (address.eql(&localhost_v4) or address.eql(&localhost_v6))
+ addresses_found += 1;
+ },
+ .canonical_name => |canonical_name| try testing.expectEqualStrings("localhost", canonical_name.bytes),
+ .end => |end| {
+ try end;
+ break;
+ },
+ } else |err| return err;
+
+ try testing.expect(addresses_found != 0);
}
{
// The tests are required to work even when there is no Internet connection,
// so some of these errors we must accept and skip the test.
- var addresses_buffer: [8]net.IpAddress = undefined;
- var canon_name_buffer: [net.HostName.max_len]u8 = undefined;
- const result = net.HostName.lookup(try .init("example.com"), io, .{
+ var canonical_name_buffer: [net.HostName.max_len]u8 = undefined;
+ var results_buffer: [16]net.HostName.LookupResult = undefined;
+ var results: Io.Queue(net.HostName.LookupResult) = .init(&results_buffer);
+
+ net.HostName.lookup(try .init("example.com"), io, &results, .{
.port = 80,
- .addresses_buffer = &addresses_buffer,
- .canonical_name_buffer = &canon_name_buffer,
- }) catch |err| switch (err) {
- error.UnknownHostName => return error.SkipZigTest,
- error.NameServerFailure => return error.SkipZigTest,
- else => return err,
- };
- _ = result;
+ .canonical_name_buffer = &canonical_name_buffer,
+ });
+
+ while (results.getOne(io)) |result| switch (result) {
+ .address => {},
+ .canonical_name => {},
+ .end => |end| {
+ end catch |err| switch (err) {
+ error.UnknownHostName => return error.SkipZigTest,
+ error.NameServerFailure => return error.SkipZigTest,
+ else => return err,
+ };
+ break;
+ },
+ } else |err| return err;
}
}
lib/std/Io/net.zig
@@ -315,8 +315,8 @@ pub const IpAddress = union(enum) {
};
/// Initiates a connection-oriented network stream.
- pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
- return io.vtable.netConnectIp(io.userdata, address, options);
+ pub fn connect(address: IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
+ return io.vtable.netConnectIp(io.userdata, &address, options);
}
};
lib/std/Io/Threaded.zig
@@ -458,7 +458,7 @@ const GroupClosure = struct {
group: *Io.Group,
/// Points to sibling `GroupClosure`. Used for walking the group to cancel all.
node: std.SinglyLinkedList.Node,
- func: *const fn (context: *anyopaque) void,
+ func: *const fn (*Io.Group, context: *anyopaque) void,
context_alignment: std.mem.Alignment,
context_len: usize,
@@ -476,7 +476,7 @@ const GroupClosure = struct {
return;
}
current_closure = closure;
- gc.func(gc.contextPointer());
+ gc.func(group, gc.contextPointer());
current_closure = null;
// In case a cancel happens after successful task completion, prevents
@@ -512,7 +512,7 @@ fn groupAsync(
group: *Io.Group,
context: []const u8,
context_alignment: std.mem.Alignment,
- start: *const fn (context: *const anyopaque) void,
+ start: *const fn (*Io.Group, context: *const anyopaque) void,
) void {
if (builtin.single_threaded) return start(context.ptr);
const pool: *Pool = @ptrCast(@alignCast(userdata));
@@ -520,7 +520,7 @@ fn groupAsync(
const gpa = pool.allocator;
const n = GroupClosure.contextEnd(context_alignment, context.len);
const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
- return start(context.ptr);
+ return start(group, context.ptr);
}));
gc.* = .{
.closure = .{
@@ -548,7 +548,7 @@ fn groupAsync(
pool.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
pool.mutex.unlock();
gc.free(gpa);
- return start(context.ptr);
+ return start(group, context.ptr);
};
pool.run_queue.prepend(&gc.closure.node);
@@ -558,7 +558,7 @@ fn groupAsync(
assert(pool.run_queue.popFirst() == &gc.closure.node);
pool.mutex.unlock();
gc.free(gpa);
- return start(context.ptr);
+ return start(group, context.ptr);
};
pool.threads.appendAssumeCapacity(thread);
}
@@ -2662,6 +2662,7 @@ fn netLookupFallible(
.{ .address = addr },
.{ .canonical_name = copyCanon(options.canonical_name_buffer, name) },
});
+ return;
} else |_| {}
}
lib/std/Io.zig
@@ -639,7 +639,7 @@ pub const VTable = struct {
/// Copied and then passed to `start`.
context: []const u8,
context_alignment: std.mem.Alignment,
- start: *const fn (context: *const anyopaque) void,
+ start: *const fn (*Group, context: *const anyopaque) void,
) void,
groupWait: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
groupCancel: *const fn (?*anyopaque, *Group, token: *anyopaque) void,
@@ -1005,7 +1005,8 @@ pub const Group = struct {
pub fn async(g: *Group, io: Io, function: anytype, args: std.meta.ArgsTuple(@TypeOf(function))) void {
const Args = @TypeOf(args);
const TypeErased = struct {
- fn start(context: *const anyopaque) void {
+ fn start(group: *Group, context: *const anyopaque) void {
+ _ = group;
const args_casted: *const Args = @ptrCast(@alignCast(context));
@call(.auto, function, args_casted.*);
}
@@ -1033,6 +1034,85 @@ pub const Group = struct {
}
};
+pub fn Select(comptime U: type) type {
+ return struct {
+ io: Io,
+ group: Group,
+ queue: Queue(U),
+ outstanding: usize,
+
+ const S = @This();
+
+ pub const Union = U;
+
+ pub const Field = std.meta.FieldEnum(U);
+
+ pub fn init(io: Io, buffer: []U) S {
+ return .{
+ .io = io,
+ .queue = .init(buffer),
+ .group = .init,
+ .outstanding = 0,
+ };
+ }
+
+ /// Calls `function` with `args` asynchronously. The resource spawned is
+ /// owned by the select.
+ ///
+ /// `function` must have return type matching the `field` field of `Union`.
+ ///
+ /// `function` *may* be called immediately, before `async` returns.
+ ///
+ /// After this is called, `wait` or `cancel` must be called before the
+ /// select is deinitialized.
+ ///
+ /// Threadsafe.
+ ///
+ /// Related:
+ /// * `Io.async`
+ /// * `Group.async`
+ pub fn async(
+ s: *S,
+ comptime field: Field,
+ function: anytype,
+ args: std.meta.ArgsTuple(@TypeOf(function)),
+ ) void {
+ const Args = @TypeOf(args);
+ const TypeErased = struct {
+ fn start(group: *Group, context: *const anyopaque) void {
+ const args_casted: *const Args = @ptrCast(@alignCast(context));
+ const unerased_select: *S = @fieldParentPtr("group", group);
+ const elem = @unionInit(U, @tagName(field), @call(.auto, function, args_casted.*));
+ unerased_select.queue.putOneUncancelable(unerased_select.io, elem);
+ }
+ };
+ _ = @atomicRmw(usize, &s.outstanding, .Add, 1, .monotonic);
+ s.io.vtable.groupAsync(s.io.userdata, &s.group, @ptrCast((&args)[0..1]), .of(Args), TypeErased.start);
+ }
+
+ /// Blocks until another task of the select finishes.
+ ///
+ /// Asserts there is at least one more `outstanding` task.
+ ///
+ /// Not threadsafe.
+ pub fn wait(s: *S) Io.Cancelable!U {
+ s.outstanding -= 1;
+ return s.queue.getOne(s.io);
+ }
+
+ /// Equivalent to `wait` but requests cancellation on all remaining
+ /// tasks owned by the select.
+ ///
+ /// It is illegal to call `wait` after this.
+ ///
+ /// Idempotent. Not threadsafe.
+ pub fn cancel(s: *S) void {
+ s.outstanding = 0;
+ s.group.cancel(s.io);
+ }
+ };
+}
+
pub const Mutex = struct {
state: State,