Commit 9ade31faaf

Andrew Kelley <andrew@ziglang.org>
2019-10-29 19:03:39
implement CNAME expansion
1 parent 67058b9
Changed files (3)
lib/std/c.zig
@@ -191,3 +191,11 @@ pub extern "c" fn getnameinfo(
 pub extern "c" fn gai_strerror(errcode: c_int) [*]const u8;
 
 pub extern "c" fn poll(fds: [*]pollfd, nfds: nfds_t, timeout: c_int) c_int;
+
+pub extern "c" fn dn_expand(
+    msg: [*]const u8,
+    eomorig: [*]const u8,
+    comp_dn: [*]const u8,
+    exp_dn: [*]u8,
+    length: c_int,
+) c_int;
lib/std/net.zig
@@ -1017,7 +1017,6 @@ fn dnsParse(
 }
 
 fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8) !void {
-    var tmp: [256]u8 = undefined;
     switch (rr) {
         os.RR_A => {
             if (data.len != 4) return error.InvalidDnsARecord;
@@ -1038,10 +1037,13 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8)
             mem.copy(u8, &new_addr.addr, data);
         },
         os.RR_CNAME => {
-            @panic("TODO dn_expand");
-            //if (__dn_expand(packet, (const unsigned char *)packet + 512,
-            //    data, tmp, sizeof tmp) > 0 && is_valid_hostname(tmp))
-            //    strcpy(ctx->canon, tmp);
+            var tmp: [256]u8 = undefined;
+            // Returns len of compressed name. strlen to get canon name.
+            _ = try os.dn_expand(packet, data, &tmp);
+            const canon_name = mem.toSliceConst(u8, &tmp);
+            if (isValidHostName(canon_name)) {
+                try ctx.canon.replaceContents(canon_name);
+            }
         },
         else => return,
     }
lib/std/os.zig
@@ -3076,3 +3076,51 @@ pub fn recvfrom(
         }
     }
 }
+
+pub const DnExpandError = error{InvalidDnsPacket};
+
+pub fn dn_expand(
+    msg: []const u8,
+    comp_dn: []const u8,
+    exp_dn: []u8,
+) DnExpandError!usize {
+    var p = comp_dn.ptr;
+    var len: usize = std.math.maxInt(usize);
+    const end = msg.ptr + msg.len;
+    if (p == end or exp_dn.len == 0) return error.InvalidDnsPacket;
+    var dest = exp_dn.ptr;
+    const dend = dest + std.math.min(exp_dn.len, 254);
+    // detect reference loop using an iteration counter
+    var i: usize = 0;
+    while (i < msg.len) : (i += 2) {
+        // loop invariants: p<end, dest<dend
+        if ((p[0] & 0xc0) != 0) {
+            if (p + 1 == end) return error.InvalidDnsPacket;
+            var j = ((p[0] & usize(0x3f)) << 8) | p[1];
+            if (len == std.math.maxInt(usize)) len = @ptrToInt(p) + 2 - @ptrToInt(comp_dn.ptr);
+            if (j >= msg.len) return error.InvalidDnsPacket;
+            p = msg.ptr + j;
+        } else if (p[0] != 0) {
+            if (dest != exp_dn.ptr) {
+                dest.* = '.';
+                dest += 1;
+            }
+            var j = p[0];
+            p += 1;
+            if (j >= @ptrToInt(end) - @ptrToInt(p) or j >= @ptrToInt(dend) - @ptrToInt(dest)) {
+                return error.InvalidDnsPacket;
+            }
+            while (j != 0) {
+                j -= 1;
+                dest.* = p[0];
+                dest += 1;
+                p += 1;
+            }
+        } else {
+            dest.* = 0;
+            if (len == std.math.maxInt(usize)) len = @ptrToInt(p) + 1 - @ptrToInt(comp_dn.ptr);
+            return len;
+        }
+    }
+    return error.InvalidDnsPacket;
+}