Commit 86308ba1e1

Andrew Kelley <andrew@ziglang.org>
2023-01-17 02:11:07
std.net.getAddressList: call WSAStartup on Windows
1 parent 62e3fdc
Changed files (2)
lib/std/os/windows.zig
@@ -1296,6 +1296,23 @@ pub fn WSACleanup() !void {
 
 var wsa_startup_mutex: std.Thread.Mutex = .{};
 
+pub fn callWSAStartup() !void {
+    wsa_startup_mutex.lock();
+    defer wsa_startup_mutex.unlock();
+
+    // Here we could use a flag to prevent multiple threads to prevent
+    // multiple calls to WSAStartup, but it doesn't matter. We're globally
+    // leaking the resource intentionally, and the mutex already prevents
+    // data races within the WSAStartup function.
+    _ = WSAStartup(2, 2) catch |err| switch (err) {
+        error.SystemNotAvailable => return error.SystemResources,
+        error.VersionNotSupported => return error.Unexpected,
+        error.BlockingOperationInProgress => return error.Unexpected,
+        error.ProcessFdQuotaExceeded => return error.ProcessFdQuotaExceeded,
+        error.Unexpected => return error.Unexpected,
+    };
+}
+
 /// Microsoft requires WSAStartup to be called to initialize, or else
 /// WSASocketW will return WSANOTINITIALISED.
 /// Since this is a standard library, we do not have the luxury of
@@ -1338,21 +1355,7 @@ pub fn WSASocketW(
                 .WSANOTINITIALISED => {
                     if (!first) return error.Unexpected;
                     first = false;
-
-                    wsa_startup_mutex.lock();
-                    defer wsa_startup_mutex.unlock();
-
-                    // Here we could use a flag to prevent multiple threads to prevent
-                    // multiple calls to WSAStartup, but it doesn't matter. We're globally
-                    // leaking the resource intentionally, and the mutex already prevents
-                    // data races within the WSAStartup function.
-                    _ = WSAStartup(2, 2) catch |err| switch (err) {
-                        error.SystemNotAvailable => return error.SystemResources,
-                        error.VersionNotSupported => return error.Unexpected,
-                        error.BlockingOperationInProgress => return error.Unexpected,
-                        error.ProcessFdQuotaExceeded => return error.ProcessFdQuotaExceeded,
-                        error.Unexpected => return error.Unexpected,
-                    };
+                    try callWSAStartup();
                     continue;
                 },
                 else => |err| return unexpectedWSAError(err),
lib/std/net.zig
@@ -746,7 +746,79 @@ pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) !*A
     const arena = result.arena.allocator();
     errdefer result.deinit();
 
-    if (builtin.target.os.tag == .windows or builtin.link_libc) {
+    if (builtin.target.os.tag == .windows) {
+        const name_c = try std.cstr.addNullByte(allocator, name);
+        defer allocator.free(name_c);
+
+        const port_c = try std.fmt.allocPrintZ(allocator, "{}", .{port});
+        defer allocator.free(port_c);
+
+        const ws2_32 = os.windows.ws2_32;
+        const hints = os.addrinfo{
+            .flags = ws2_32.AI.NUMERICSERV,
+            .family = os.AF.UNSPEC,
+            .socktype = os.SOCK.STREAM,
+            .protocol = os.IPPROTO.TCP,
+            .canonname = null,
+            .addr = null,
+            .addrlen = 0,
+            .next = null,
+        };
+        var res: *os.addrinfo = undefined;
+        var first = true;
+        while (true) {
+            const rc = ws2_32.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res);
+            switch (@intToEnum(os.windows.ws2_32.WinsockError, @intCast(u16, rc))) {
+                @intToEnum(os.windows.ws2_32.WinsockError, 0) => break,
+                .WSATRY_AGAIN => return error.TemporaryNameServerFailure,
+                .WSANO_RECOVERY => return error.NameServerFailure,
+                .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported,
+                .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory,
+                .WSAHOST_NOT_FOUND => return error.UnknownHostName,
+                .WSATYPE_NOT_FOUND => return error.ServiceUnavailable,
+                .WSAEINVAL => unreachable,
+                .WSAESOCKTNOSUPPORT => unreachable,
+                .WSANOTINITIALISED => {
+                    if (!first) return error.Unexpected;
+                    first = false;
+                    try os.windows.callWSAStartup();
+                    continue;
+                },
+                else => |err| return os.windows.unexpectedWSAError(err),
+            }
+        }
+        defer ws2_32.freeaddrinfo(res);
+
+        const addr_count = blk: {
+            var count: usize = 0;
+            var it: ?*os.addrinfo = res;
+            while (it) |info| : (it = info.next) {
+                if (info.addr != null) {
+                    count += 1;
+                }
+            }
+            break :blk count;
+        };
+        result.addrs = try arena.alloc(Address, addr_count);
+
+        var it: ?*os.addrinfo = res;
+        var i: usize = 0;
+        while (it) |info| : (it = info.next) {
+            const addr = info.addr orelse continue;
+            result.addrs[i] = Address.initPosix(@alignCast(4, addr));
+
+            if (info.canonname) |n| {
+                if (result.canon_name == null) {
+                    result.canon_name = try arena.dupe(u8, mem.sliceTo(n, 0));
+                }
+            }
+            i += 1;
+        }
+
+        return result;
+    }
+
+    if (builtin.link_libc) {
         const name_c = try std.cstr.addNullByte(allocator, name);
         defer allocator.free(name_c);
 
@@ -765,19 +837,7 @@ pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) !*A
             .next = null,
         };
         var res: *os.addrinfo = undefined;
-        const rc = sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res);
-        if (builtin.target.os.tag == .windows) switch (@intToEnum(os.windows.ws2_32.WinsockError, @intCast(u16, rc))) {
-            @intToEnum(os.windows.ws2_32.WinsockError, 0) => {},
-            .WSATRY_AGAIN => return error.TemporaryNameServerFailure,
-            .WSANO_RECOVERY => return error.NameServerFailure,
-            .WSAEAFNOSUPPORT => return error.AddressFamilyNotSupported,
-            .WSA_NOT_ENOUGH_MEMORY => return error.OutOfMemory,
-            .WSAHOST_NOT_FOUND => return error.UnknownHostName,
-            .WSATYPE_NOT_FOUND => return error.ServiceUnavailable,
-            .WSAEINVAL => unreachable,
-            .WSAESOCKTNOSUPPORT => unreachable,
-            else => |err| return os.windows.unexpectedWSAError(err),
-        } else switch (rc) {
+        switch (sys.getaddrinfo(name_c.ptr, port_c.ptr, &hints, &res)) {
             @intToEnum(sys.EAI, 0) => {},
             .ADDRFAMILY => return error.HostLacksNetworkAddresses,
             .AGAIN => return error.TemporaryNameServerFailure,
@@ -824,6 +884,7 @@ pub fn getAddressList(allocator: mem.Allocator, name: []const u8, port: u16) !*A
 
         return result;
     }
+
     if (builtin.target.os.tag == .linux) {
         const flags = std.c.AI.NUMERICSERV;
         const family = os.AF.UNSPEC;