Commit 6cc72af03d

Jonathan Marler <johnnymarler@gmail.com>
2020-07-26 07:29:02
Provide Ip4Address and Ip6Address in addition to Address
1 parent c95091e
Changed files (1)
lib
lib/std/net.zig
@@ -14,8 +14,8 @@ const has_unix_sockets = @hasDecl(os, "sockaddr_un");
 
 pub const Address = extern union {
     any: os.sockaddr,
-    in: os.sockaddr_in,
-    in6: os.sockaddr_in6,
+    in: Ip4Address,
+    in6: Ip6Address,
     un: if (has_unix_sockets) os.sockaddr_un else void,
 
     // TODO this crashed the compiler. https://github.com/ziglang/zig/issues/3512
@@ -76,19 +76,227 @@ pub const Address = extern union {
         }
     }
 
+    pub fn parseIp6(buf: []const u8, port: u16) !Address {
+        return Address{.in6 = try Ip6Address.parse(buf, port) };
+    }
+
+    pub fn resolveIp6(buf: []const u8, port: u16) !Address {
+        return Address{.in6 = try Ip6Address.resolve(buf, port) };
+    }
+
+    pub fn parseIp4(buf: []const u8, port: u16) !Address {
+        return Address {.in = try Ip4Address.parse(buf, port) };
+    }
+
+    pub fn initIp4(addr: [4]u8, port: u16) Address {
+        return Address{.in = Ip4Address.init(addr, port) };
+    }
+
+    pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address {
+        return Address{.in6 = Ip6Address.init(addr, port, flowinfo, scope_id) };
+    }
+
+    pub fn initUnix(path: []const u8) !Address {
+        var sock_addr = os.sockaddr_un{
+            .family = os.AF_UNIX,
+            .path = undefined,
+        };
+
+        // this enables us to have the proper length of the socket in getOsSockLen
+        mem.set(u8, &sock_addr.path, 0);
+
+        if (path.len > sock_addr.path.len) return error.NameTooLong;
+        mem.copy(u8, &sock_addr.path, path);
+
+        return Address{ .un = sock_addr };
+    }
+
+    /// Returns the port in native endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn getPort(self: Address) u16 {
+        return switch (self.any.family) {
+            os.AF_INET => self.in.getPort(),
+            os.AF_INET6 => self.in6.getPort(),
+            else => unreachable,
+        };
+    }
+
+    /// `port` is native-endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn setPort(self: *Address, port: u16) void {
+        switch (self.any.family) {
+            os.AF_INET => self.in.setPort(port),
+            os.AF_INET6 => self.in6.setPort(port),
+            else => unreachable,
+        }
+    }
+
+    /// Asserts that `addr` is an IP address.
+    /// This function will read past the end of the pointer, with a size depending
+    /// on the address family.
+    pub fn initPosix(addr: *align(4) const os.sockaddr) Address {
+        switch (addr.family) {
+            os.AF_INET => return Address{ .in = Ip4Address{ .sa = @ptrCast(*const os.sockaddr_in, addr).*} },
+            os.AF_INET6 => return Address{ .in6 = Ip6Address{ .sa = @ptrCast(*const os.sockaddr_in6, addr).*} },
+            else => unreachable,
+        }
+    }
+
+    pub fn format(
+        self: Address,
+        comptime fmt: []const u8,
+        options: std.fmt.FormatOptions,
+        out_stream: anytype,
+    ) !void {
+        switch (self.any.family) {
+            os.AF_INET => try self.in.format(fmt, options, out_stream),
+            os.AF_INET6 => try self.in6.format(fmt, options, out_stream),
+            os.AF_UNIX => {
+                if (!has_unix_sockets) {
+                    unreachable;
+                }
+
+                try std.fmt.format(out_stream, "{}", .{&self.un.path});
+            },
+            else => unreachable,
+        }
+    }
+
+    pub fn eql(a: Address, b: Address) bool {
+        const a_bytes = @ptrCast([*]const u8, &a.any)[0..a.getOsSockLen()];
+        const b_bytes = @ptrCast([*]const u8, &b.any)[0..b.getOsSockLen()];
+        return mem.eql(u8, a_bytes, b_bytes);
+    }
+
+    pub fn getOsSockLen(self: Address) os.socklen_t {
+        switch (self.any.family) {
+            os.AF_INET => return self.in.getOsSockLen(),
+            os.AF_INET6 => return self.in6.getOsSockLen(),
+            os.AF_UNIX => {
+                if (!has_unix_sockets) {
+                    unreachable;
+                }
+
+                const path_len = std.mem.len(@ptrCast([*:0]const u8, &self.un.path));
+                return @intCast(os.socklen_t, @sizeOf(os.sockaddr_un) - self.un.path.len + path_len);
+            },
+            else => unreachable,
+        }
+    }
+};
+
+pub const Ip4Address = extern struct {
+    sa: os.sockaddr_in,
+
+    pub fn parse(buf: []const u8, port: u16) !Ip4Address {
+        var result = Ip4Address{
+            .sa = .{
+                .port = mem.nativeToBig(u16, port),
+                .addr = undefined,
+            }
+        };
+        const out_ptr = mem.sliceAsBytes(@as(*[1]u32, &result.sa.addr)[0..]);
+
+        var x: u8 = 0;
+        var index: u8 = 0;
+        var saw_any_digits = false;
+        for (buf) |c| {
+            if (c == '.') {
+                if (!saw_any_digits) {
+                    return error.InvalidCharacter;
+                }
+                if (index == 3) {
+                    return error.InvalidEnd;
+                }
+                out_ptr[index] = x;
+                index += 1;
+                x = 0;
+                saw_any_digits = false;
+            } else if (c >= '0' and c <= '9') {
+                saw_any_digits = true;
+                x = try std.math.mul(u8, x, 10);
+                x = try std.math.add(u8, x, c - '0');
+            } else {
+                return error.InvalidCharacter;
+            }
+        }
+        if (index == 3 and saw_any_digits) {
+            out_ptr[index] = x;
+            return result;
+        }
+
+        return error.Incomplete;
+    }
+
+    pub fn resolveIp(name: []const u8, port: u16) !Ip4Address {
+        if (parse(name, port)) |ip4| return ip4 else |err| switch (err) {
+            error.Overflow,
+            error.InvalidEnd,
+            error.InvalidCharacter,
+            error.Incomplete,
+            => {},
+        }
+        return error.InvalidIPAddressFormat;
+    }
+
+    pub fn init(addr: [4]u8, port: u16) Ip4Address {
+        return Ip4Address {
+            .sa = os.sockaddr_in{
+                .port = mem.nativeToBig(u16, port),
+                .addr = @ptrCast(*align(1) const u32, &addr).*,
+            },
+        };
+    }
+
+    /// Returns the port in native endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn getPort(self: Ip4Address) u16 {
+        return mem.bigToNative(u16, self.sa.port);
+    }
+
+    /// `port` is native-endian.
+    /// Asserts that the address is ip4 or ip6.
+    pub fn setPort(self: *Ip4Address, port: u16) void {
+        self.sa.port = mem.nativeToBig(u16, port);
+    }
+
+    pub fn format(
+        self: Ip4Address,
+        comptime fmt: []const u8,
+        options: std.fmt.FormatOptions,
+        out_stream: anytype,
+    ) !void {
+        const bytes = @ptrCast(*const [4]u8, &self.sa.addr);
+        try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{
+            bytes[0],
+            bytes[1],
+            bytes[2],
+            bytes[3],
+            self.getPort(),
+        });
+    }
+
+    pub fn getOsSockLen(self: Ip4Address) os.socklen_t {
+        return @sizeOf(os.sockaddr_in);
+    }
+};
+
+pub const Ip6Address = extern struct {
+    sa: os.sockaddr_in6,
+
     /// Parse a given IPv6 address string into an Address.
     /// Assumes the Scope ID of the address is fully numeric.
     /// For non-numeric addresses, see `resolveIp6`.
-    pub fn parseIp6(buf: []const u8, port: u16) !Address {
-        var result = Address{
-            .in6 = os.sockaddr_in6{
+    pub fn parse(buf: []const u8, port: u16) !Ip6Address {
+        var result = Ip6Address{
+            .sa = os.sockaddr_in6{
                 .scope_id = 0,
                 .port = mem.nativeToBig(u16, port),
                 .flowinfo = 0,
                 .addr = undefined,
             },
         };
-        var ip_slice = result.in6.addr[0..];
+        var ip_slice = result.sa.addr[0..];
 
         var tail: [16]u8 = undefined;
 
@@ -101,10 +309,10 @@ pub const Address = extern union {
             if (scope_id) {
                 if (c >= '0' and c <= '9') {
                     const digit = c - '0';
-                    if (@mulWithOverflow(u32, result.in6.scope_id, 10, &result.in6.scope_id)) {
+                    if (@mulWithOverflow(u32, result.sa.scope_id, 10, &result.sa.scope_id)) {
                         return error.Overflow;
                     }
-                    if (@addWithOverflow(u32, result.in6.scope_id, digit, &result.in6.scope_id)) {
+                    if (@addWithOverflow(u32, result.sa.scope_id, digit, &result.sa.scope_id)) {
                         return error.Overflow;
                     }
                 } else {
@@ -141,10 +349,10 @@ pub const Address = extern union {
                     return error.InvalidIpv4Mapping;
                 }
                 const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1;
-                const addr = (parseIp4(buf[start_index..], 0) catch {
+                const addr = (Ip4Address.parse(buf[start_index..], 0) catch {
                     return error.InvalidIpv4Mapping;
-                }).in.addr;
-                ip_slice = result.in6.addr[0..];
+                }).sa.addr;
+                ip_slice = result.sa.addr[0..];
                 ip_slice[10] = 0xff;
                 ip_slice[11] = 0xff;
 
@@ -180,22 +388,22 @@ pub const Address = extern union {
             index += 1;
             ip_slice[index] = @truncate(u8, x);
             index += 1;
-            mem.copy(u8, result.in6.addr[16 - index ..], ip_slice[0..index]);
+            mem.copy(u8, result.sa.addr[16 - index ..], ip_slice[0..index]);
             return result;
         }
     }
 
-    pub fn resolveIp6(buf: []const u8, port: u16) !Address {
+    pub fn resolve(buf: []const u8, port: u16) !Ip6Address {
         // TODO: Unify the implementations of resolveIp6 and parseIp6.
-        var result = Address{
-            .in6 = os.sockaddr_in6{
+        var result = Ip6Address{
+            .sa = os.sockaddr_in6{
                 .scope_id = 0,
                 .port = mem.nativeToBig(u16, port),
                 .flowinfo = 0,
                 .addr = undefined,
             },
         };
-        var ip_slice = result.in6.addr[0..];
+        var ip_slice = result.sa.addr[0..];
 
         var tail: [16]u8 = undefined;
 
@@ -256,10 +464,10 @@ pub const Address = extern union {
                     return error.InvalidIpv4Mapping;
                 }
                 const start_index = mem.lastIndexOfScalar(u8, buf[0..i], ':').? + 1;
-                const addr = (parseIp4(buf[start_index..], 0) catch {
+                const addr = (Ip4Address.parse(buf[start_index..], 0) catch {
                     return error.InvalidIpv4Mapping;
-                }).in.addr;
-                ip_slice = result.in6.addr[0..];
+                }).sa.addr;
+                ip_slice = result.sa.addr[0..];
                 ip_slice[10] = 0xff;
                 ip_slice[11] = 0xff;
 
@@ -299,7 +507,7 @@ pub const Address = extern union {
             };
         }
 
-        result.in6.scope_id = resolved_scope_id;
+        result.sa.scope_id = resolved_scope_id;
 
         if (index == 14) {
             ip_slice[14] = @truncate(u8, x >> 8);
@@ -310,63 +518,14 @@ pub const Address = extern union {
             index += 1;
             ip_slice[index] = @truncate(u8, x);
             index += 1;
-            mem.copy(u8, result.in6.addr[16 - index ..], ip_slice[0..index]);
-            return result;
-        }
-    }
-
-    pub fn parseIp4(buf: []const u8, port: u16) !Address {
-        var result = Address{
-            .in = os.sockaddr_in{
-                .port = mem.nativeToBig(u16, port),
-                .addr = undefined,
-            },
-        };
-        const out_ptr = mem.sliceAsBytes(@as(*[1]u32, &result.in.addr)[0..]);
-
-        var x: u8 = 0;
-        var index: u8 = 0;
-        var saw_any_digits = false;
-        for (buf) |c| {
-            if (c == '.') {
-                if (!saw_any_digits) {
-                    return error.InvalidCharacter;
-                }
-                if (index == 3) {
-                    return error.InvalidEnd;
-                }
-                out_ptr[index] = x;
-                index += 1;
-                x = 0;
-                saw_any_digits = false;
-            } else if (c >= '0' and c <= '9') {
-                saw_any_digits = true;
-                x = try std.math.mul(u8, x, 10);
-                x = try std.math.add(u8, x, c - '0');
-            } else {
-                return error.InvalidCharacter;
-            }
-        }
-        if (index == 3 and saw_any_digits) {
-            out_ptr[index] = x;
+            mem.copy(u8, result.sa.addr[16 - index ..], ip_slice[0..index]);
             return result;
         }
-
-        return error.Incomplete;
     }
 
-    pub fn initIp4(addr: [4]u8, port: u16) Address {
-        return Address{
-            .in = os.sockaddr_in{
-                .port = mem.nativeToBig(u16, port),
-                .addr = @ptrCast(*align(1) const u32, &addr).*,
-            },
-        };
-    }
-
-    pub fn initIp6(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Address {
-        return Address{
-            .in6 = os.sockaddr_in6{
+    pub fn init(addr: [16]u8, port: u16, flowinfo: u32, scope_id: u32) Ip6Address {
+        return Ip6Address{
+            .sa = os.sockaddr_in6{
                 .addr = addr,
                 .port = mem.nativeToBig(u16, port),
                 .flowinfo = flowinfo,
@@ -375,147 +534,71 @@ pub const Address = extern union {
         };
     }
 
-    pub fn initUnix(path: []const u8) !Address {
-        var sock_addr = os.sockaddr_un{
-            .family = os.AF_UNIX,
-            .path = undefined,
-        };
-
-        // this enables us to have the proper length of the socket in getOsSockLen
-        mem.set(u8, &sock_addr.path, 0);
-
-        if (path.len > sock_addr.path.len) return error.NameTooLong;
-        mem.copy(u8, &sock_addr.path, path);
-
-        return Address{ .un = sock_addr };
-    }
-
     /// Returns the port in native endian.
     /// Asserts that the address is ip4 or ip6.
-    pub fn getPort(self: Address) u16 {
-        const big_endian_port = switch (self.any.family) {
-            os.AF_INET => self.in.port,
-            os.AF_INET6 => self.in6.port,
-            else => unreachable,
-        };
-        return mem.bigToNative(u16, big_endian_port);
+    pub fn getPort(self: Ip6Address) u16 {
+        return mem.bigToNative(u16, self.sa.port);
     }
 
     /// `port` is native-endian.
     /// Asserts that the address is ip4 or ip6.
-    pub fn setPort(self: *Address, port: u16) void {
-        const ptr = switch (self.any.family) {
-            os.AF_INET => &self.in.port,
-            os.AF_INET6 => &self.in6.port,
-            else => unreachable,
-        };
-        ptr.* = mem.nativeToBig(u16, port);
-    }
-
-    /// Asserts that `addr` is an IP address.
-    /// This function will read past the end of the pointer, with a size depending
-    /// on the address family.
-    pub fn initPosix(addr: *align(4) const os.sockaddr) Address {
-        switch (addr.family) {
-            os.AF_INET => return Address{ .in = @ptrCast(*const os.sockaddr_in, addr).* },
-            os.AF_INET6 => return Address{ .in6 = @ptrCast(*const os.sockaddr_in6, addr).* },
-            else => unreachable,
-        }
+    pub fn setPort(self: *Ip6Address, port: u16) void {
+        self.sa.port = mem.nativeToBig(u16, port);
     }
 
     pub fn format(
-        self: Address,
+        self: Ip6Address,
         comptime fmt: []const u8,
         options: std.fmt.FormatOptions,
         out_stream: anytype,
     ) !void {
-        switch (self.any.family) {
-            os.AF_INET => {
-                const port = mem.bigToNative(u16, self.in.port);
-                const bytes = @ptrCast(*const [4]u8, &self.in.addr);
-                try std.fmt.format(out_stream, "{}.{}.{}.{}:{}", .{
-                    bytes[0],
-                    bytes[1],
-                    bytes[2],
-                    bytes[3],
-                    port,
-                });
-            },
-            os.AF_INET6 => {
-                const port = mem.bigToNative(u16, self.in6.port);
-                if (mem.eql(u8, self.in6.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) {
-                    try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{
-                        self.in6.addr[12],
-                        self.in6.addr[13],
-                        self.in6.addr[14],
-                        self.in6.addr[15],
-                        port,
-                    });
-                    return;
-                }
-                const big_endian_parts = @ptrCast(*align(1) const [8]u16, &self.in6.addr);
-                const native_endian_parts = switch (builtin.endian) {
-                    .Big => big_endian_parts.*,
-                    .Little => blk: {
-                        var buf: [8]u16 = undefined;
-                        for (big_endian_parts) |part, i| {
-                            buf[i] = mem.bigToNative(u16, part);
-                        }
-                        break :blk buf;
-                    },
-                };
-                try out_stream.writeAll("[");
-                var i: usize = 0;
-                var abbrv = false;
-                while (i < native_endian_parts.len) : (i += 1) {
-                    if (native_endian_parts[i] == 0) {
-                        if (!abbrv) {
-                            try out_stream.writeAll(if (i == 0) "::" else ":");
-                            abbrv = true;
-                        }
-                        continue;
-                    }
-                    try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]});
-                    if (i != native_endian_parts.len - 1) {
-                        try out_stream.writeAll(":");
-                    }
+        const port = mem.bigToNative(u16, self.sa.port);
+        if (mem.eql(u8, self.sa.addr[0..12], &[_]u8{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff })) {
+            try std.fmt.format(out_stream, "[::ffff:{}.{}.{}.{}]:{}", .{
+                self.sa.addr[12],
+                self.sa.addr[13],
+                self.sa.addr[14],
+                self.sa.addr[15],
+                port,
+            });
+            return;
+        }
+        const big_endian_parts = @ptrCast(*align(1) const [8]u16, &self.sa.addr);
+        const native_endian_parts = switch (builtin.endian) {
+            .Big => big_endian_parts.*,
+            .Little => blk: {
+                var buf: [8]u16 = undefined;
+                for (big_endian_parts) |part, i| {
+                    buf[i] = mem.bigToNative(u16, part);
                 }
-                try std.fmt.format(out_stream, "]:{}", .{port});
+                break :blk buf;
             },
-            os.AF_UNIX => {
-                if (!has_unix_sockets) {
-                    unreachable;
+        };
+        try out_stream.writeAll("[");
+        var i: usize = 0;
+        var abbrv = false;
+        while (i < native_endian_parts.len) : (i += 1) {
+            if (native_endian_parts[i] == 0) {
+                if (!abbrv) {
+                    try out_stream.writeAll(if (i == 0) "::" else ":");
+                    abbrv = true;
                 }
-
-                try std.fmt.format(out_stream, "{}", .{&self.un.path});
-            },
-            else => unreachable,
+                continue;
+            }
+            try std.fmt.format(out_stream, "{x}", .{native_endian_parts[i]});
+            if (i != native_endian_parts.len - 1) {
+                try out_stream.writeAll(":");
+            }
         }
+        try std.fmt.format(out_stream, "]:{}", .{port});
     }
 
-    pub fn eql(a: Address, b: Address) bool {
-        const a_bytes = @ptrCast([*]const u8, &a.any)[0..a.getOsSockLen()];
-        const b_bytes = @ptrCast([*]const u8, &b.any)[0..b.getOsSockLen()];
-        return mem.eql(u8, a_bytes, b_bytes);
-    }
-
-    pub fn getOsSockLen(self: Address) os.socklen_t {
-        switch (self.any.family) {
-            os.AF_INET => return @sizeOf(os.sockaddr_in),
-            os.AF_INET6 => return @sizeOf(os.sockaddr_in6),
-            os.AF_UNIX => {
-                if (!has_unix_sockets) {
-                    unreachable;
-                }
-
-                const path_len = std.mem.len(@ptrCast([*:0]const u8, &self.un.path));
-                return @intCast(os.socklen_t, @sizeOf(os.sockaddr_un) - self.un.path.len + path_len);
-            },
-            else => unreachable,
-        }
+    pub fn getOsSockLen(self: Ip6Address) os.socklen_t {
+        return @sizeOf(os.sockaddr_in6);
     }
 };
 
+
 pub fn connectUnixSocket(path: []const u8) !fs.File {
     const opt_non_block = if (std.io.is_async) os.SOCK_NONBLOCK else 0;
     const sockfd = try os.socket(
@@ -777,7 +860,7 @@ fn linuxLookupName(
         @memset(@ptrCast([*]u8, &sa6), 0, @sizeOf(os.sockaddr_in6));
         var da6 = os.sockaddr_in6{
             .family = os.AF_INET6,
-            .scope_id = addr.addr.in6.scope_id,
+            .scope_id = addr.addr.in6.sa.scope_id,
             .port = 65535,
             .flowinfo = 0,
             .addr = [1]u8{0} ** 16,
@@ -795,7 +878,7 @@ fn linuxLookupName(
         var salen: os.socklen_t = undefined;
         var dalen: os.socklen_t = undefined;
         if (addr.addr.any.family == os.AF_INET6) {
-            mem.copy(u8, &da6.addr, &addr.addr.in6.addr);
+            mem.copy(u8, &da6.addr, &addr.addr.in6.sa.addr);
             da = @ptrCast(*os.sockaddr, &da6);
             dalen = @sizeOf(os.sockaddr_in6);
             sa = @ptrCast(*os.sockaddr, &sa6);
@@ -803,8 +886,8 @@ fn linuxLookupName(
         } else {
             mem.copy(u8, &sa6.addr, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff");
             mem.copy(u8, &da6.addr, "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff");
-            mem.writeIntNative(u32, da6.addr[12..], addr.addr.in.addr);
-            da4.addr = addr.addr.in.addr;
+            mem.writeIntNative(u32, da6.addr[12..], addr.addr.in.sa.addr);
+            da4.addr = addr.addr.in.sa.addr;
             da = @ptrCast(*os.sockaddr, &da4);
             dalen = @sizeOf(os.sockaddr_in);
             sa = @ptrCast(*os.sockaddr, &sa4);