Commit 3b34622368

Andrew Kelley <andrew@ziglang.org>
2025-10-11 07:37:54
std.Io: add unix domain sockets API
note that "reuseaddr" does nothing for these
1 parent 9e681ca
Changed files (4)
lib/std/Io/net/test.zig
@@ -290,15 +290,16 @@ test "listen on a unix socket, send bytes, receive bytes" {
     const socket_path = try generateFileName("socket.unix");
     defer testing.allocator.free(socket_path);
 
-    const socket_addr = try net.IpAddress.initUnix(socket_path);
+    const socket_addr = try net.UnixAddress.init(socket_path);
     defer std.fs.cwd().deleteFile(socket_path) catch {};
 
-    var server = try socket_addr.listen(io, .{});
-    defer server.deinit(io);
+    var server = try socket_addr.listen(io);
+    defer server.socket.close(io);
 
     const S = struct {
         fn clientFn(path: []const u8) !void {
-            var stream = try net.connectUnixSocket(path);
+            const server_path: net.UnixAddress = try .init(path);
+            var stream = try server_path.connect(io);
             defer stream.close(io);
 
             var stream_writer = stream.writer(io, &.{});
@@ -319,23 +320,6 @@ test "listen on a unix socket, send bytes, receive bytes" {
     try testing.expectEqualSlices(u8, "Hello world!", buf[0..n]);
 }
 
-test "listen on a unix socket with reuse_address option" {
-    if (!net.has_unix_sockets) return error.SkipZigTest;
-    // Windows doesn't implement reuse port option.
-    if (builtin.os.tag == .windows) return error.SkipZigTest;
-
-    const io = testing.io;
-
-    const socket_path = try generateFileName("socket.unix");
-    defer testing.allocator.free(socket_path);
-
-    const socket_addr = try net.Address.initUnix(socket_path);
-    defer std.fs.cwd().deleteFile(socket_path) catch {};
-
-    var server = try socket_addr.listen(io, .{ .reuse_address = true });
-    server.deinit(io);
-}
-
 fn generateFileName(base_name: []const u8) ![]const u8 {
     const random_bytes_count = 12;
     const sub_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count);
lib/std/Io/net.zig
@@ -219,7 +219,7 @@ pub const IpAddress = union(enum) {
     /// Waits for a TCP connection. When using this API, `bind` does not need
     /// to be called. The returned `Server` has an open `stream`.
     pub fn listen(address: IpAddress, io: Io, options: ListenOptions) ListenError!Server {
-        return io.vtable.listen(io.userdata, address, options);
+        return io.vtable.netListenIp(io.userdata, address, options);
     }
 
     pub const BindError = error{
@@ -262,7 +262,7 @@ pub const IpAddress = union(enum) {
     /// One bound `Socket` can be used to receive messages from multiple
     /// different addresses.
     pub fn bind(address: *const IpAddress, io: Io, options: BindOptions) BindError!Socket {
-        return io.vtable.ipBind(io.userdata, address, options);
+        return io.vtable.netBindIp(io.userdata, address, options);
     }
 
     pub const ConnectError = error{
@@ -298,7 +298,7 @@ pub const IpAddress = union(enum) {
 
     /// Initiates a connection-oriented network stream.
     pub fn connect(address: *const IpAddress, io: Io, options: ConnectOptions) ConnectError!Stream {
-        return io.vtable.ipConnect(io.userdata, address, options);
+        return io.vtable.netConnectIp(io.userdata, address, options);
     }
 };
 
@@ -775,6 +775,39 @@ pub const Ip6Address = struct {
     };
 };
 
+pub const UnixAddress = struct {
+    path: []const u8,
+
+    pub const max_len = 108;
+
+    pub const InitError = error{NameTooLong};
+
+    pub fn init(p: []const u8) InitError!UnixAddress {
+        if (p.len > max_len) return error.NameTooLong;
+        return .{ .path = p };
+    }
+
+    pub const ListenError = error{};
+
+    pub fn listen(ua: UnixAddress, io: Io) ListenError!Server {
+        assert(ua.path.len <= max_len);
+        return .{ .socket = .{
+            .handle = try io.vtable.netListenUnix(io.userdata, ua),
+            .address = .{ .ip4 = .loopback(0) },
+        } };
+    }
+
+    pub const ConnectError = error{};
+
+    pub fn connect(ua: UnixAddress, io: Io) ConnectError!Stream {
+        assert(ua.path.len <= max_len);
+        return .{ .socket = .{
+            .handle = try io.vtable.netConnectUnix(io.userdata, ua),
+            .address = .{ .ip4 = .loopback(0) },
+        } };
+    }
+};
+
 pub const ReceiveFlags = packed struct(u8) {
     oob: bool = false,
     peek: bool = false,
@@ -917,6 +950,7 @@ pub const Socket = struct {
         else => std.posix.fd_t,
     };
 
+    /// Leaves `address` in a valid state.
     pub fn close(s: *Socket, io: Io) void {
         io.vtable.netClose(io.userdata, s.handle);
         s.handle = undefined;
@@ -1156,7 +1190,7 @@ pub const Server = struct {
 
     /// Blocks until a client connects to the server.
     pub fn accept(s: *Server, io: Io) AcceptError!Stream {
-        return io.vtable.accept(io.userdata, s);
+        return io.vtable.netAccept(io.userdata, s.socket.handle);
     }
 };
 
lib/std/Io/Threaded.zig
@@ -203,22 +203,24 @@ pub fn io(pool: *Pool) Io {
                 else => sleepPosix,
             },
 
-            .listen = switch (builtin.os.tag) {
+            .netListenIp = switch (builtin.os.tag) {
                 .windows => @panic("TODO"),
-                else => listenPosix,
+                else => netListenIpPosix,
             },
-            .accept = switch (builtin.os.tag) {
+            .netListenUnix = netListenUnix,
+            .netAccept = switch (builtin.os.tag) {
                 .windows => @panic("TODO"),
-                else => acceptPosix,
+                else => netAcceptPosix,
             },
-            .ipBind = switch (builtin.os.tag) {
+            .netBindIp = switch (builtin.os.tag) {
                 .windows => @panic("TODO"),
-                else => ipBindPosix,
+                else => netBindIpPosix,
             },
-            .ipConnect = switch (builtin.os.tag) {
+            .netConnectIp = switch (builtin.os.tag) {
                 .windows => @panic("TODO"),
-                else => ipConnectPosix,
+                else => netConnectIpPosix,
             },
+            .netConnectUnix = netConnectUnix,
             .netClose = netClose,
             .netRead = switch (builtin.os.tag) {
                 .windows => @panic("TODO"),
@@ -1636,7 +1638,7 @@ fn select(userdata: ?*anyopaque, futures: []const *Io.AnyFuture) usize {
     return result.?;
 }
 
-fn listenPosix(
+fn netListenIpPosix(
     userdata: ?*anyopaque,
     address: Io.net.IpAddress,
     options: Io.net.IpAddress.ListenOptions,
@@ -1702,6 +1704,13 @@ fn listenPosix(
     };
 }
 
+fn netListenUnix(userdata: ?*anyopaque, address: Io.net.UnixAddress) Io.net.UnixAddress.ListenError!Io.net.Socket.Handle {
+    const pool: *Pool = @ptrCast(@alignCast(userdata));
+    _ = pool;
+    _ = address;
+    @panic("TODO");
+}
+
 fn posixBind(pool: *Pool, socket_fd: posix.socket_t, addr: *const posix.sockaddr, addr_len: posix.socklen_t) !void {
     while (true) {
         try pool.checkCancel();
@@ -1784,7 +1793,7 @@ fn setSocketOption(pool: *Pool, fd: posix.fd_t, level: i32, opt_name: u32, optio
     }
 }
 
-fn ipConnectPosix(
+fn netConnectIpPosix(
     userdata: ?*anyopaque,
     address: *const Io.net.IpAddress,
     options: Io.net.IpAddress.ConnectOptions,
@@ -1806,7 +1815,14 @@ fn ipConnectPosix(
     } };
 }
 
-fn ipBindPosix(
+fn netConnectUnix(userdata: ?*anyopaque, address: Io.net.UnixAddress) Io.net.UnixAddress.ConnectError!Io.net.Socket.Handle {
+    const pool: *Pool = @ptrCast(@alignCast(userdata));
+    _ = pool;
+    _ = address;
+    @panic("TODO");
+}
+
+fn netBindIpPosix(
     userdata: ?*anyopaque,
     address: *const Io.net.IpAddress,
     options: Io.net.IpAddress.BindOptions,
@@ -1871,9 +1887,8 @@ fn openSocketPosix(pool: *Pool, family: posix.sa_family_t, options: Io.net.IpAdd
 const socket_flags_unsupported = builtin.os.tag.isDarwin() or native_os == .haiku; // 💩💩
 const have_accept4 = !socket_flags_unsupported;
 
-fn acceptPosix(userdata: ?*anyopaque, server: *Io.net.Server) Io.net.Server.AcceptError!Io.net.Stream {
+fn netAcceptPosix(userdata: ?*anyopaque, listen_fd: Io.net.Socket.Handle) Io.net.Server.AcceptError!Io.net.Stream {
     const pool: *Pool = @ptrCast(@alignCast(userdata));
-    const listen_fd = server.socket.handle;
     var storage: PosixAddress = undefined;
     var addr_len: posix.socklen_t = @sizeOf(PosixAddress);
     const fd = while (true) {
lib/std/Io.zig
@@ -672,10 +672,12 @@ pub const VTable = struct {
     now: *const fn (?*anyopaque, Clock) Clock.Error!Timestamp,
     sleep: *const fn (?*anyopaque, Timeout) SleepError!void,
 
-    listen: *const fn (?*anyopaque, address: net.IpAddress, options: net.IpAddress.ListenOptions) net.IpAddress.ListenError!net.Server,
-    accept: *const fn (?*anyopaque, server: *net.Server) net.Server.AcceptError!net.Stream,
-    ipBind: *const fn (?*anyopaque, address: *const net.IpAddress, options: net.IpAddress.BindOptions) net.IpAddress.BindError!net.Socket,
-    ipConnect: *const fn (?*anyopaque, address: *const net.IpAddress, options: net.IpAddress.ConnectOptions) net.IpAddress.ConnectError!net.Stream,
+    netListenIp: *const fn (?*anyopaque, address: net.IpAddress, options: net.IpAddress.ListenOptions) net.IpAddress.ListenError!net.Server,
+    netAccept: *const fn (?*anyopaque, server: net.Socket.Handle) net.Server.AcceptError!net.Stream,
+    netBindIp: *const fn (?*anyopaque, address: *const net.IpAddress, options: net.IpAddress.BindOptions) net.IpAddress.BindError!net.Socket,
+    netConnectIp: *const fn (?*anyopaque, address: *const net.IpAddress, options: net.IpAddress.ConnectOptions) net.IpAddress.ConnectError!net.Stream,
+    netListenUnix: *const fn (?*anyopaque, net.UnixAddress) net.UnixAddress.ListenError!net.Socket.Handle,
+    netConnectUnix: *const fn (?*anyopaque, net.UnixAddress) net.UnixAddress.ConnectError!net.Socket.Handle,
     netSend: *const fn (?*anyopaque, net.Socket.Handle, []net.OutgoingMessage, net.SendFlags) struct { ?net.Socket.SendError, usize },
     netReceive: *const fn (?*anyopaque, net.Socket.Handle, message_buffer: []net.IncomingMessage, data_buffer: []u8, net.ReceiveFlags, Timeout) struct { ?net.Socket.ReceiveTimeoutError, usize },
     netRead: *const fn (?*anyopaque, src: net.Stream, data: [][]u8) net.Stream.Reader.Error!usize,