master
  1//! An already-validated host name. A valid host name:
  2//! * Has length less than or equal to `max_len`.
  3//! * Is valid UTF-8.
  4//! * Lacks ASCII characters other than alphanumeric, '-', and '.'.
  5const HostName = @This();
  6
  7const builtin = @import("builtin");
  8const native_os = builtin.os.tag;
  9
 10const std = @import("../../std.zig");
 11const Io = std.Io;
 12const IpAddress = Io.net.IpAddress;
 13const Ip6Address = Io.net.Ip6Address;
 14const assert = std.debug.assert;
 15const Stream = Io.net.Stream;
 16
/// Externally managed memory. Already checked to be valid.
bytes: []const u8,

/// Maximum length, in bytes, that `validate` accepts for a host name.
pub const max_len = 255;

/// Reasons `validate` (and therefore `init`) rejects a byte slice.
pub const ValidateError = error{
    /// The slice is longer than `max_len` bytes.
    NameTooLong,
    /// The slice is not valid UTF-8, or it contains an ASCII byte other than
    /// an alphanumeric character, '-', or '.'.
    InvalidHostName,
};
 26
 27pub fn validate(bytes: []const u8) ValidateError!void {
 28    if (bytes.len > max_len) return error.NameTooLong;
 29    if (!std.unicode.utf8ValidateSlice(bytes)) return error.InvalidHostName;
 30    for (bytes) |byte| {
 31        if (!std.ascii.isAscii(byte) or byte == '.' or byte == '-' or std.ascii.isAlphanumeric(byte)) {
 32            continue;
 33        }
 34        return error.InvalidHostName;
 35    }
 36}
 37
 38pub fn init(bytes: []const u8) ValidateError!HostName {
 39    try validate(bytes);
 40    return .{ .bytes = bytes };
 41}
 42
 43pub fn sameParentDomain(parent_host: HostName, child_host: HostName) bool {
 44    const parent_bytes = parent_host.bytes;
 45    const child_bytes = child_host.bytes;
 46    if (!std.ascii.endsWithIgnoreCase(child_bytes, parent_bytes)) return false;
 47    if (child_bytes.len == parent_bytes.len) return true;
 48    if (parent_bytes.len > child_bytes.len) return false;
 49    return child_bytes[child_bytes.len - parent_bytes.len - 1] == '.';
 50}
 51
 52test sameParentDomain {
 53    try std.testing.expect(!sameParentDomain(try .init("foo.com"), try .init("bar.com")));
 54    try std.testing.expect(sameParentDomain(try .init("foo.com"), try .init("foo.com")));
 55    try std.testing.expect(sameParentDomain(try .init("foo.com"), try .init("bar.foo.com")));
 56    try std.testing.expect(!sameParentDomain(try .init("bar.foo.com"), try .init("foo.com")));
 57}
 58
 59/// Domain names are case-insensitive (RFC 5890, Section 2.3.2.4)
 60pub fn eql(a: HostName, b: HostName) bool {
 61    return std.ascii.eqlIgnoreCase(a.bytes, b.bytes);
 62}
 63
/// Options accepted by `lookup`.
pub const LookupOptions = struct {
    /// Port number supplied with the request.
    /// NOTE(review): presumably attached to every resulting `IpAddress` —
    /// confirm against the `netLookup` implementation.
    port: u16,
    /// Caller-provided storage of `max_len` bytes.
    /// NOTE(review): presumably backs the `LookupResult.canonical_name`
    /// entry — confirm against the `netLookup` implementation.
    canonical_name_buffer: *[max_len]u8,
    /// `null` means either.
    family: ?IpAddress.Family = null,
};
 70
/// Errors a lookup can report; delivered through `LookupResult.end`.
pub const LookupError = error{
    UnknownHostName,
    ResolvConfParseFailed,
    InvalidDnsARecord,
    InvalidDnsAAAARecord,
    InvalidDnsCnameRecord,
    NameServerFailure,
    /// Failed to open or read "/etc/hosts" or "/etc/resolv.conf".
    DetectingNetworkConfigurationFailed,
} || Io.Clock.Error || IpAddress.BindError || Io.Cancelable;
 81
/// One item added to the results queue by `lookup`.
pub const LookupResult = union(enum) {
    /// A resolved address.
    address: IpAddress,
    /// The canonical name of the host.
    canonical_name: HostName,
    /// Final item; carries the overall outcome of the lookup.
    end: LookupError!void,
};
 87
 88/// Adds any number of `IpAddress` into resolved, exactly one canonical_name,
 89/// and then always finishes by adding one `LookupResult.end` entry.
 90///
 91/// Guaranteed not to block if provided queue has capacity at least 16.
 92pub fn lookup(
 93    host_name: HostName,
 94    io: Io,
 95    resolved: *Io.Queue(LookupResult),
 96    options: LookupOptions,
 97) void {
 98    return io.vtable.netLookup(io.userdata, host_name, resolved, options);
 99}
100
/// Errors returned by `expand`: the packet is malformed, or the expanded name
/// is rejected by `validate`.
pub const ExpandError = error{InvalidDnsPacket} || ValidateError;
102
/// Decompresses the DNS name located in `packet` at offset `start_i`, writing
/// its dotted form into `dest_buffer`.
///
/// Returns the number of bytes the name occupies in `packet` starting at
/// `start_i`, along with the expanded `HostName`, whose bytes point into
/// `dest_buffer`.
///
/// Returns `error.InvalidDnsPacket` when the encoding runs past the packet,
/// does not fit the destination, or does not terminate within the step limit.
///
/// Asserts `dest_buffer` has length at least `max_len`.
pub fn expand(noalias packet: []const u8, start_i: usize, noalias dest_buffer: []u8) ExpandError!struct { usize, HostName } {
    const out = dest_buffer[0..max_len];

    var read_pos = start_i;
    var out_len: usize = 0;
    // Fixed when the first compression pointer is followed: the encoded name
    // at `start_i` ends right after that two-byte pointer.
    var consumed: ?usize = null;

    // Limiting the number of steps guarantees termination even when
    // compression pointers form a loop.
    var steps_left = packet.len / 2;
    while (steps_left != 0) : (steps_left -= 1) {
        if (read_pos >= packet.len) return error.InvalidDnsPacket;
        const head = packet[read_pos];

        if (head == 0) {
            // Zero-length label: end of the name.
            return .{
                consumed orelse read_pos - start_i + 1,
                try .init(out[0..out_len]),
            };
        }

        if ((head & 0xc0) != 0) {
            // Compression pointer: low 6 bits of this byte and the whole next
            // byte form the offset at which the name continues.
            if (read_pos + 1 >= packet.len) return error.InvalidDnsPacket;
            const target: usize = (@as(usize, head & 0x3F) << 8) | packet[read_pos + 1];
            if (target >= packet.len) return error.InvalidDnsPacket;
            if (consumed == null) consumed = (read_pos + 2) - start_i;
            read_pos = target;
            continue;
        }

        // Ordinary label of `head` bytes; labels after the first are joined with '.'.
        if (out_len != 0) {
            out[out_len] = '.';
            out_len += 1;
        }
        const label_len: usize = head;
        if (read_pos + 1 + label_len > packet.len) return error.InvalidDnsPacket;
        if (out_len + label_len + 1 > out.len) return error.InvalidDnsPacket;
        @memcpy(out[out_len..][0..label_len], packet[read_pos + 1 ..][0..label_len]);
        out_len += label_len;
        read_pos += 1 + label_len;
    }
    return error.InvalidDnsPacket;
}
147
/// Resource record kinds named by this module. The enum is non-exhaustive, so
/// every other `u8` value is representable without a name.
pub const DnsRecord = enum(u8) {
    A = 1,
    CNAME = 5,
    AAAA = 28,
    _,
};
154
/// Iterator over the answer records of a raw DNS response packet.
pub const DnsResponse = struct {
    /// The complete response packet.
    bytes: []const u8,
    /// Offset into `bytes` at which `next` resumes parsing.
    bytes_index: u32,
    /// Number of answers `next` has yet to return.
    answers_remaining: u16,

    /// One answer record, as returned by `next`.
    pub const Answer = struct {
        /// Record kind, taken from a single byte of the record.
        rr: DnsRecord,
        /// The packet the record was read from (the same slice as `bytes`).
        packet: []const u8,
        /// Offset of the record's data within `packet`.
        data_off: u32,
        /// Length of the record's data in bytes.
        data_len: u16,
    };

    pub const Error = error{InvalidDnsPacket};

    /// Prepares iteration over the answers in response packet `r`.
    ///
    /// Skips the question entries so that `next` starts at the first answer.
    /// Returns `error.InvalidDnsPacket` if `r` is shorter than 12 bytes or a
    /// question entry is truncated.
    pub fn init(r: []const u8) Error!DnsResponse {
        // Header fields are read from offsets below 12; entries start at 12.
        if (r.len < 12) return error.InvalidDnsPacket;
        // Nonzero low 4 bits of byte 3 (the RCODE field of RFC 1035): the
        // response is treated as carrying no answers.
        if ((r[3] & 15) != 0) return .{ .bytes = r, .bytes_index = 3, .answers_remaining = 0 };
        var i: u32 = 12;
        // Big-endian question count at offset 4.
        var query_count = std.mem.readInt(u16, r[4..6], .big);
        while (query_count != 0) : (query_count -= 1) {
            // Advance over bytes in 1..=127; the wrapping subtraction makes a
            // zero byte (255 after wrapping) and any byte >= 128 stop the scan.
            while (i < r.len and r[i] -% 1 < 127) i += 1;
            if (r.len - i < 6) return error.InvalidDnsPacket;
            // Step over the stopping byte (two bytes if it is nonzero) plus
            // four more bytes.
            i = i + 5 + @intFromBool(r[i] != 0);
        }
        return .{
            .bytes = r,
            .bytes_index = i,
            // Big-endian answer count at offset 6.
            .answers_remaining = std.mem.readInt(u16, r[6..8], .big),
        };
    }

    /// Returns the next answer, or `null` once all answers have been returned.
    ///
    /// Returns `error.InvalidDnsPacket` if the record is truncated. In that
    /// case `answers_remaining` has already been decremented while
    /// `bytes_index` is left unchanged.
    pub fn next(dr: *DnsResponse) Error!?Answer {
        if (dr.answers_remaining == 0) return null;
        dr.answers_remaining -= 1;
        const r = dr.bytes;
        var i = dr.bytes_index;
        // Same scan as in `init`: stop at a zero byte or a byte >= 128.
        while (i < r.len and r[i] -% 1 < 127) i += 1;
        if (r.len - i < 12) return error.InvalidDnsPacket;
        // Step over the stopping byte (two bytes if it is nonzero).
        i = i + 1 + @intFromBool(r[i] != 0);
        // Big-endian data length, stored 8 bytes into the fixed part of the record.
        const len = std.mem.readInt(u16, r[i + 8 ..][0..2], .big);
        if (i + 10 + len > r.len) return error.InvalidDnsPacket;
        // Resume after this record's data on the following call.
        defer dr.bytes_index = i + 10 + len;
        return .{
            // Only the byte at `i + 1` is used to select the record kind.
            .rr = @enumFromInt(r[i + 1]),
            .packet = r,
            .data_off = i + 10,
            .data_len = len,
        };
    }
};
205
/// Errors returned by `connect`: those of the lookup combined with those of
/// connecting to an individual address.
pub const ConnectError = LookupError || IpAddress.ConnectError;
207
/// Connects to `host_name` on `port`, returning the first successfully
/// connected `Stream` reported by `connectMany`.
///
/// If no attempt succeeds, returns the error reported by the lookup if there
/// is one; otherwise the most recently seen connection error that did not
/// abort the operation, or `error.UnknownHostName` if no such error was seen.
pub fn connect(
    host_name: HostName,
    io: Io,
    port: u16,
    options: IpAddress.ConnectOptions,
) ConnectError!Stream {
    var connect_many_buffer: [32]ConnectManyResult = undefined;
    var connect_many_queue: Io.Queue(ConnectManyResult) = .init(&connect_many_buffer);

    var connect_many = io.async(connectMany, .{ host_name, io, port, &connect_many_queue, options });
    var saw_end = false;
    // Runs on every exit path: cancels the `connectMany` task and then, unless
    // the `.end` entry was already consumed below, drains the queue up to
    // `.end`, closing every connected stream that was not returned.
    defer {
        connect_many.cancel(io);
        if (!saw_end) while (true) switch (connect_many_queue.getOneUncancelable(io)) {
            .connection => |loser| if (loser) |s| s.close(io) else |_| continue,
            .end => break,
        };
    }

    // Returned when the lookup ends without error but nothing connected.
    var aggregate_error: ConnectError = error.UnknownHostName;

    while (connect_many_queue.getOne(io)) |result| switch (result) {
        .connection => |connection| if (connection) |stream| return stream else |err| switch (err) {
            // These errors abort the whole operation immediately.
            error.SystemResources,
            error.OptionUnsupported,
            error.ProcessFdQuotaExceeded,
            error.SystemFdQuotaExceeded,
            error.Canceled,
            => |e| return e,

            error.WouldBlock => return error.Unexpected,

            // Any other failure is remembered; later results may still succeed.
            else => |e| aggregate_error = e,
        },
        .end => |end| {
            saw_end = true;
            // An error from the lookup takes precedence over connection errors.
            try end;
            return aggregate_error;
        },
    } else |err| switch (err) {
        error.Canceled => |e| return e,
    }
}
251
/// One item added to the results queue by `connectMany`.
pub const ConnectManyResult = union(enum) {
    /// Outcome of a connection attempt to a single address.
    connection: IpAddress.ConnectError!Stream,
    /// Final item; carries the outcome of the lookup, or `error.Canceled`.
    end: ConnectError!void,
};
256
/// Looks up `host_name` and starts a connection attempt on `port` for every
/// address the lookup yields.
///
/// The outcome of each attempt is added to `results` as a `.connection` entry.
/// One `.end` entry is added after the attempts have been waited on or
/// canceled; it carries the lookup's outcome, or `error.Canceled` if waiting
/// on the lookup was canceled.
pub fn connectMany(
    host_name: HostName,
    io: Io,
    port: u16,
    results: *Io.Queue(ConnectManyResult),
    options: IpAddress.ConnectOptions,
) void {
    var name_storage: [max_len]u8 = undefined;
    var lookup_storage: [32]LookupResult = undefined;
    var lookups: Io.Queue(LookupResult) = .init(&lookup_storage);
    var tasks: Io.Group = .init;
    defer tasks.cancel(io);

    const lookup_options: LookupOptions = .{
        .port = port,
        .canonical_name_buffer = &name_storage,
    };
    tasks.async(io, lookup, .{ host_name, io, &lookups, lookup_options });

    while (true) {
        const item = lookups.getOne(io) catch |err| switch (err) {
            error.Canceled => |e| {
                tasks.cancel(io);
                results.putOneUncancelable(io, .{ .end = e });
                return;
            },
        };
        switch (item) {
            // Each address gets its own connection attempt in the group.
            .address => |address| tasks.async(io, enqueueConnection, .{ address, io, results, options }),
            // The canonical name is not needed for connecting.
            .canonical_name => {},
            .end => |lookup_outcome| {
                tasks.wait(io);
                results.putOneUncancelable(io, .{ .end = lookup_outcome });
                return;
            },
        }
    }
}
292
/// Connects to `address` and adds the outcome, whether a stream or an error,
/// to `queue` as a `.connection` entry.
fn enqueueConnection(
    address: IpAddress,
    io: Io,
    queue: *Io.Queue(ConnectManyResult),
    options: IpAddress.ConnectOptions,
) void {
    const outcome: ConnectManyResult = .{ .connection = address.connect(io, options) };
    queue.putOneUncancelable(io, outcome);
}
301
/// Resolver configuration in the format of resolv.conf(5): name servers, the
/// search list, and the `ndots`, `attempts`, and `timeout` options.
pub const ResolvConf = struct {
    /// `attempts` option. `init` starts from 2; `parse` caps parsed values at 10.
    attempts: u32,
    /// `ndots` option. `init` starts from 1; `parse` caps parsed values at 15.
    ndots: u32,
    /// `timeout` option. `init` starts from 5; `parse` caps parsed values at 60.
    timeout_seconds: u32,
    /// Storage for name servers; only the first `nameservers_len` are valid.
    nameservers_buffer: [max_nameservers]IpAddress,
    nameservers_len: usize,
    /// Text following the most recent `domain` or `search` directive; only
    /// the first `search_len` bytes are valid.
    search_buffer: [max_len]u8,
    search_len: usize,

    /// According to resolv.conf(5) there is a maximum of 3 nameservers in this
    /// file.
    pub const max_nameservers = 3;

    /// Loads the configuration from "/etc/resolv.conf".
    ///
    /// If the file is missing or inaccessible, the result has 127.0.0.1:53 as
    /// its only name server and default option values.
    ///
    /// Returns `error.StreamTooLong` if a line is longer than 512 bytes.
    pub fn init(io: Io) !ResolvConf {
        var rc: ResolvConf = .{
            .nameservers_buffer = undefined,
            .nameservers_len = 0,
            .search_buffer = undefined,
            .search_len = 0,
            .ndots = 1,
            .timeout_seconds = 5,
            .attempts = 2,
        };

        const file = Io.File.openAbsolute(io, "/etc/resolv.conf", .{}) catch |err| switch (err) {
            error.FileNotFound,
            error.NotDir,
            error.AccessDenied,
            => {
                try addNumeric(&rc, io, "127.0.0.1", 53);
                return rc;
            },

            else => |e| return e,
        };
        defer file.close(io);

        var line_buf: [512]u8 = undefined;
        var file_reader = file.reader(io, &line_buf);
        parse(&rc, io, &file_reader.interface) catch |err| switch (err) {
            // Report the underlying file error rather than the generic one.
            error.ReadFailed => return file_reader.err.?,
            else => |e| return e,
        };
        return rc;
    }

    const Directive = enum { options, nameserver, domain, search };
    const Option = enum { ndots, attempts, timeout };

    /// Reads newline-terminated lines from `reader` and applies each
    /// recognized directive to `rc`. Text after '#' and unrecognized
    /// directives or options are ignored. A `domain` or `search` directive
    /// whose text does not fit in `search_buffer` is ignored as well.
    ///
    /// If no name server has been added by the end of the input,
    /// 127.0.0.1:53 is added.
    ///
    /// Returns `error.EndOfStream` if the input ends with data that is not
    /// terminated by a newline.
    pub fn parse(rc: *ResolvConf, io: Io, reader: *Io.Reader) !void {
        while (reader.takeSentinel('\n')) |line_with_comment| {
            const line = line: {
                var split = std.mem.splitScalar(u8, line_with_comment, '#');
                break :line split.first();
            };
            var line_it = std.mem.tokenizeAny(u8, line, " \t");

            const token = line_it.next() orelse continue;
            switch (std.meta.stringToEnum(Directive, token) orelse continue) {
                .options => while (line_it.next()) |sub_tok| {
                    // Options of interest have the form "name:value".
                    var colon_it = std.mem.splitScalar(u8, sub_tok, ':');
                    const name = colon_it.first();
                    const value_txt = colon_it.next() orelse continue;
                    // Values too large for `u8` saturate; each option is
                    // capped to a smaller maximum below anyway.
                    const value = std.fmt.parseInt(u8, value_txt, 10) catch |err| switch (err) {
                        error.Overflow => 255,
                        error.InvalidCharacter => continue,
                    };
                    switch (std.meta.stringToEnum(Option, name) orelse continue) {
                        .ndots => rc.ndots = @min(value, 15),
                        .attempts => rc.attempts = @min(value, 10),
                        .timeout => rc.timeout_seconds = @min(value, 60),
                    }
                },
                .nameserver => {
                    const ip_txt = line_it.next() orelse continue;
                    try addNumeric(rc, io, ip_txt, 53);
                },
                .domain, .search => {
                    const rest = line_it.rest();
                    // A line can be longer than `search_buffer`; ignore such a
                    // directive instead of slicing past the end of the buffer.
                    if (rest.len > rc.search_buffer.len) continue;
                    @memcpy(rc.search_buffer[0..rest.len], rest);
                    rc.search_len = rest.len;
                },
            }
        } else |err| switch (err) {
            // Reaching the end is only an error if an unterminated line remains.
            error.EndOfStream => if (reader.bufferedLen() != 0) return error.EndOfStream,
            else => |e| return e,
        }

        if (rc.nameservers_len == 0) {
            try addNumeric(rc, io, "127.0.0.1", 53);
        }
    }

    /// Resolves the numeric address `name` with `port` and appends it to the
    /// name server list. Does nothing when the list is already full.
    fn addNumeric(rc: *ResolvConf, io: Io, name: []const u8, port: u16) !void {
        if (rc.nameservers_len < rc.nameservers_buffer.len) {
            rc.nameservers_buffer[rc.nameservers_len] = try .resolve(io, name, port);
            rc.nameservers_len += 1;
        }
    }

    /// Returns the name servers added so far.
    pub fn nameservers(rc: *const ResolvConf) []const IpAddress {
        return rc.nameservers_buffer[0..rc.nameservers_len];
    }
};
407
test ResolvConf {
    // Three `nameserver` lines, plus an option without a "name:value" form
    // that `parse` skips.
    const resolv_conf_text =
        \\# Generated by resolvconf
        \\nameserver 1.0.0.1
        \\nameserver 1.1.1.1
        \\nameserver fe80::e0e:76ff:fed4:cf22
        \\options edns0
        \\
    ;

    var conf: ResolvConf = .{
        .attempts = 2,
        .ndots = 1,
        .timeout_seconds = 5,
        .nameservers_buffer = undefined,
        .nameservers_len = 0,
        .search_buffer = undefined,
        .search_len = 0,
    };

    var text_reader: Io.Reader = .fixed(resolv_conf_text);
    try conf.parse(std.testing.io, &text_reader);

    const servers = conf.nameservers();
    try std.testing.expectEqual(3, servers.len);
}