Commit 85a6fea3be

Andrew Kelley <andrew@ziglang.org>
2025-10-03 05:45:16
std.Io.net.HostName: implement DNS name expansion
1 parent f1a590c
Changed files (1)
lib
std
lib/std/Io/net/HostName.zig
@@ -51,6 +51,7 @@ pub const LookupError = error{
     ResolvConfParseFailed,
     InvalidDnsARecord,
     InvalidDnsAAAARecord,
+    InvalidDnsCnameRecord,
     NameServerFailure,
 } || Io.Timestamp.Error || IpAddress.BindError || Io.File.OpenError || Io.File.Reader.Error || Io.Cancelable;
 
@@ -381,16 +382,8 @@ fn lookupDns(io: Io, lookup_canon_name: []const u8, rc: *const ResolvConf, optio
                 addresses_len += 1;
             },
             std.posix.RR.CNAME => {
-                _ = &canonical_name;
-                @panic("TODO");
-                //var tmp: [256]u8 = undefined;
-                //// Returns len of compressed name. strlen to get canon name.
-                //_ = try posix.dn_expand(packet, record.data, &tmp);
-                //const canon_name = mem.sliceTo(&tmp, 0);
-                //if (isValidHostName(canon_name)) {
-                //    ctx.canon.items.len = 0;
-                //    try ctx.canon.appendSlice(gpa, canon_name);
-                //}
+                _, canonical_name = expand(record.packet, record.data_off, options.canonical_name_buffer) catch
+                    return error.InvalidDnsCnameRecord;
             },
             else => continue,
         };
@@ -525,51 +518,50 @@ fn writeResolutionQuery(q: *[280]u8, op: u4, dname: []const u8, class: u8, ty: u
     return n;
 }
 
-pub const ExpandDomainNameError = error{InvalidDnsPacket};
-
-pub fn expandDomainName(
-    msg: []const u8,
-    comp_dn: []const u8,
-    exp_dn: []u8,
-) ExpandDomainNameError!usize {
-    // This implementation is ported from musl libc.
-    // A more idiomatic "ziggy" implementation would be welcome.
-    var p = comp_dn.ptr;
-    var len: usize = std.math.maxInt(usize);
-    const end = msg.ptr + msg.len;
-    if (p == end or exp_dn.len == 0) return error.InvalidDnsPacket;
-    var dest = exp_dn.ptr;
-    const dend = dest + @min(exp_dn.len, 254);
-    // detect reference loop using an iteration counter
-    var i: usize = 0;
-    while (i < msg.len) : (i += 2) {
-        // loop invariants: p<end, dest<dend
-        if ((p[0] & 0xc0) != 0) {
-            if (p + 1 == end) return error.InvalidDnsPacket;
-            const j = @as(usize, p[0] & 0x3f) << 8 | p[1];
-            if (len == std.math.maxInt(usize)) len = @intFromPtr(p) + 2 - @intFromPtr(comp_dn.ptr);
-            if (j >= msg.len) return error.InvalidDnsPacket;
-            p = msg.ptr + j;
-        } else if (p[0] != 0) {
-            if (dest != exp_dn.ptr) {
-                dest[0] = '.';
-                dest += 1;
-            }
-            var j = p[0];
-            p += 1;
-            if (j >= @intFromPtr(end) - @intFromPtr(p) or j >= @intFromPtr(dend) - @intFromPtr(dest)) {
-                return error.InvalidDnsPacket;
-            }
-            while (j != 0) {
-                j -= 1;
-                dest[0] = p[0];
-                dest += 1;
-                p += 1;
+pub const ExpandError = error{InvalidDnsPacket} || InitError;
+
+/// Decompresses a DNS name.
+///
+/// Returns number of bytes consumed from `packet` starting at `i`,
+/// along with the expanded `HostName`.
+///
+/// Asserts `buffer` is has length at least `max_len`.
+pub fn expand(noalias packet: []const u8, start_i: usize, noalias dest_buffer: []u8) ExpandError!struct { usize, HostName } {
+    const dest = dest_buffer[0..max_len];
+
+    var i = start_i;
+    var dest_i: usize = 0;
+    var len: ?usize = null;
+
+    // Detect reference loop using an iteration counter.
+    for (0..packet.len / 2) |_| {
+        if (i >= packet.len) return error.InvalidDnsPacket;
+
+        const c = packet[i];
+        if ((c & 0xc0) != 0) {
+            if (i + 1 >= packet.len) return error.InvalidDnsPacket;
+            const j: usize = (@as(usize, c & 0x3F) << 8) | packet[i + 1];
+            if (j >= packet.len) return error.InvalidDnsPacket;
+            if (len == null) len = (i + 2) - start_i;
+            i = j;
+        } else if (c != 0) {
+            if (dest_i != 0) {
+                dest[dest_i] = '.';
+                dest_i += 1;
             }
+            const label_len: usize = c;
+            if (i + 1 + label_len > packet.len) return error.InvalidDnsPacket;
+            if (dest_i + label_len + 1 > dest.len) return error.InvalidDnsPacket;
+            @memcpy(dest[dest_i..][0..label_len], packet[i + 1 ..][0..label_len]);
+            dest_i += label_len;
+            i += 1 + label_len;
         } else {
-            dest[0] = 0;
-            if (len == std.math.maxInt(usize)) len = @intFromPtr(p) + 1 - @intFromPtr(comp_dn.ptr);
-            return len;
+            dest[dest_i] = 0;
+            dest_i += 1;
+            return .{
+                len orelse i - start_i + 1,
+                try .init(dest[0..dest_i]),
+            };
         }
     }
     return error.InvalidDnsPacket;