Commit 0a4130f364

Nameless <truemedian@gmail.com>
2023-03-07 06:35:35
std.http: handle relative redirects
1 parent fd2f906
Changed files (4)
lib
lib/std/crypto/tls/Client.zig
@@ -89,7 +89,7 @@ pub const StreamInterface = struct {
 };
 
 pub fn InitError(comptime Stream: type) type {
-    return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error {
+    return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{
         InsufficientEntropy,
         DiskQuota,
         LockViolation,
lib/std/http/Client.zig
@@ -29,9 +29,10 @@ const ConnectionPool = std.TailQueue(Connection);
 const ConnectionNode = ConnectionPool.Node;
 
 /// Acquires an existing connection from the connection pool. This function is threadsafe.
-pub fn acquire(client: *Client, node: *ConnectionNode) void {
-    client.connection_mutex.lock();
-    defer client.connection_mutex.unlock();
+/// If the caller already holds the connection mutex, it should pass `true` for `held`.
+pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
+    if (!held) client.connection_mutex.lock();
+    defer if (!held) client.connection_mutex.unlock();
 
     client.connection_pool.remove(node);
     client.connection_used.append(node);
@@ -40,16 +41,17 @@ pub fn acquire(client: *Client, node: *ConnectionNode) void {
 /// Tries to release a connection back to the connection pool. This function is threadsafe.
 /// If the connection is marked as closing, it will be closed instead.
 pub fn release(client: *Client, node: *ConnectionNode) void {
+    // The mutex is taken up front so that removing the node from
+    // `connection_used` is protected on both the closing and the pooling path.
+    client.connection_mutex.lock();
+    defer client.connection_mutex.unlock();
+
+    client.connection_used.remove(node);
+
     if (node.data.closing) {
         node.data.close(client);
 
         return client.allocator.destroy(node);
     }
 
-    client.connection_mutex.lock();
-    defer client.connection_mutex.unlock();
-
-    client.connection_used.remove(node);
     client.connection_pool.append(node);
 }
 
@@ -83,7 +85,7 @@ pub const Connection = struct {
         }
     }
 
-    pub const ReadError = std.net.Stream.ReadError || error{
+    pub const ReadError = net.Stream.ReadError || error{
         TlsConnectionTruncated,
         TlsRecordOverflow,
         TlsDecodeError,
@@ -115,7 +117,7 @@ pub const Connection = struct {
         }
     }
 
-    pub const WriteError = std.net.Stream.WriteError || error{};
+    pub const WriteError = net.Stream.WriteError || error{};
     pub const Writer = std.io.Writer(*Connection, WriteError, write);
 
     pub fn writer(conn: *Connection) Writer {
@@ -139,14 +141,21 @@ pub const Request = struct {
     const read_buffer_size = 8192;
     const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
 
+    uri: Uri,
     client: *Client,
     connection: *ConnectionNode,
-    redirects_left: u32,
     response: Response,
     /// These are stored in Request so that they are available when following
     /// redirects.
     headers: Headers,
 
+    redirects_left: u32,
+    handle_redirects: bool,
+    compression_init: bool,
+
+    /// Used as an allocator for resolving redirect locations.
+    arena: std.heap.ArenaAllocator,
+
     /// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning.
     read_buffer: [read_buffer_size]u8 = undefined,
     read_buffer_start: ReadBufferIndex = 0,
@@ -661,6 +670,7 @@ pub const Request = struct {
     pub const Headers = struct {
         version: http.Version = .@"HTTP/1.1",
         method: http.Method = .GET,
+        user_agent: []const u8 = "Zig (std.http)",
         connection: http.Connection = .keep_alive,
         transfer_encoding: RequestTransfer = .none,
 
@@ -668,6 +678,7 @@ pub const Request = struct {
     };
 
     pub const Options = struct {
+        handle_redirects: bool = true,
         max_redirects: u32 = 3,
         header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
 
@@ -703,10 +714,11 @@ pub const Request = struct {
             req.client.release(req.connection);
         }
 
+        req.arena.deinit();
         req.* = undefined;
     }
 
-    const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{
+    const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{
         UnexpectedEndOfStream,
         TooManyHttpRedirects,
         HttpRedirectMissingLocation,
@@ -723,9 +735,7 @@ pub const Request = struct {
         var index: usize = 0;
         while (index == 0) {
             const amt = try req.readRawAdvanced(buffer[index..]);
-            const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
-
-            if (amt == 0 and zero_means_end) break;
+            if (amt == 0 and req.response.done) break;
             index += amt;
         }
 
@@ -769,6 +779,8 @@ pub const Request = struct {
                 }
             } else if (req.response.headers.content_length) |content_length| {
                 req.response.next_chunk_length = content_length;
+
+                if (content_length == 0) req.response.done = true;
             } else {
                 req.response.done = true;
             }
@@ -779,7 +791,7 @@ pub const Request = struct {
         return 0;
     }
 
-    pub const WaitForCompleteHeadError = ReadRawError || error {
+    pub const WaitForCompleteHeadError = ReadRawError || error{
         UnexpectedEndOfStream,
 
         HttpHeadersExceededSizeLimit,
@@ -810,27 +822,8 @@ pub const Request = struct {
 
     /// This one can return 0 without meaning EOF.
     fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
-        if (req.response.done) {
-            if (req.response.headers.status.class() == .redirect) {
-                if (req.redirects_left == 0) return error.TooManyHttpRedirects;
-
-                const location = req.response.headers.location orelse
-                    return error.HttpRedirectMissingLocation;
-                const new_url = try std.Uri.parse(location);
-                const new_req = try req.client.request(new_url, req.headers, .{
-                    .max_redirects = req.redirects_left - 1,
-                    .header_strategy = if (req.response.header_bytes_owned) .{
-                        .dynamic = req.response.max_header_bytes,
-                    } else .{
-                        .static = req.response.header_bytes.unusedCapacitySlice(),
-                    },
-                });
-                req.deinit();
-                req.* = new_req;
-            } else {
-                return 0;
-            }
-        }
+        assert(req.response.state.isContent());
+        if (req.response.done) return 0;
 
         // var in: []const u8 = undefined;
         if (req.read_buffer_start == req.read_buffer_len) {
@@ -851,7 +844,7 @@ pub const Request = struct {
                     const data_avail = req.response.next_chunk_length;
                     const out_avail = buffer.len;
 
-                    if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+                    if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
                         const can_read = @intCast(usize, @min(buf_avail, data_avail));
                         req.response.next_chunk_length -= can_read;
 
@@ -859,7 +852,6 @@ pub const Request = struct {
                             req.client.release(req.connection);
                             req.connection = undefined;
                             req.response.done = true;
-                            continue;
                         }
 
                         return 0; // skip over as much data as possible
@@ -943,7 +935,7 @@ pub const Request = struct {
                     const data_avail = req.response.next_chunk_length;
                     const out_avail = buffer.len - out_index;
 
-                    if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+                    if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
                         const can_read = @intCast(usize, @min(buf_avail, data_avail));
                         req.response.next_chunk_length -= can_read;
 
@@ -990,9 +982,41 @@ pub const Request = struct {
     }
 
     pub fn read(req: *Request, buffer: []u8) ReadError!usize {
-        if (!req.response.state.isContent()) try req.waitForCompleteHead();
+        while (true) {
+            if (!req.response.state.isContent()) try req.waitForCompleteHead();
+
+            if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
+                assert(try req.readRaw(buffer) == 0);
+
+                if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+                const location = req.response.headers.location orelse
+                    return error.HttpRedirectMissingLocation;
+                const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
+
+                var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
+                const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
+                errdefer new_arena.deinit();
+
+                req.arena.deinit();
+                req.arena = new_arena;
+
+                const new_req = try req.client.request(resolved_url, req.headers, .{
+                    .max_redirects = req.redirects_left - 1,
+                    .header_strategy = if (req.response.header_bytes_owned) .{
+                        .dynamic = req.response.max_header_bytes,
+                    } else .{
+                        .static = req.response.header_bytes.unusedCapacitySlice(),
+                    },
+                });
+                req.deinit();
+                req.* = new_req;
+            } else {
+                break;
+            }
+        }
 
-        if (req.response.compression == .none and req.response.state.isContent()) {
+        if (req.response.compression == .none) {
             if (req.response.headers.transfer_compression) |compression| {
                 switch (compression) {
                     .compress => unreachable,
@@ -1084,6 +1108,8 @@ pub const Request = struct {
 };
 
 pub fn deinit(client: *Client) void {
+    client.connection_mutex.lock();
+
     var next = client.connection_pool.first;
     while (next) |node| {
         next = node.next;
@@ -1106,7 +1132,7 @@ pub fn deinit(client: *Client) void {
     client.* = undefined;
 }
 
-pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream);
+pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
 
 pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
     { // Search through the connection pool for a potential connection.
@@ -1120,7 +1146,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
             const same_protocol = node.data.protocol == protocol;
 
             if (same_host and same_port and same_protocol) {
-                client.acquire(node);
+                client.acquire(node, true);
                 return node;
             }
 
@@ -1168,6 +1194,7 @@ pub const RequestError = ConnectError || Connection.WriteError || error{
     InvalidPadding,
     MissingEndCertificateMarker,
     Unseekable,
+    EndOfStream,
 };
 
 pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
@@ -1196,27 +1223,52 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
     }
 
     var req: Request = .{
+        .uri = uri,
         .client = client,
         .headers = headers,
         .connection = try client.connect(host, port, protocol),
         .redirects_left = options.max_redirects,
+        .handle_redirects = options.handle_redirects,
+        .compression_init = false,
         .response = switch (options.header_strategy) {
             .dynamic => |max| Request.Response.initDynamic(max),
             .static => |buf| Request.Response.initStatic(buf),
         },
+        .arena = undefined,
     };
 
+    req.arena = std.heap.ArenaAllocator.init(client.allocator);
+
     {
         var buffered = std.io.bufferedWriter(req.connection.data.writer());
         const writer = buffered.writer();
 
+        const escaped_path = try Uri.escapePath(client.allocator, uri.path);
+        defer client.allocator.free(escaped_path);
+
+        const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
+        defer if (escaped_query) |q| client.allocator.free(q);
+
+        const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
+        defer if (escaped_fragment) |f| client.allocator.free(f);
+
         try writer.writeAll(@tagName(headers.method));
         try writer.writeByte(' ');
-        try writer.writeAll(uri.path);
+        try writer.writeAll(escaped_path);
+        if (escaped_query) |q| {
+            try writer.writeByte('?');
+            try writer.writeAll(q);
+        }
+        if (escaped_fragment) |f| {
+            try writer.writeByte('#');
+            try writer.writeAll(f);
+        }
         try writer.writeByte(' ');
         try writer.writeAll(@tagName(headers.version));
         try writer.writeAll("\r\nHost: ");
         try writer.writeAll(host);
+        try writer.writeAll("\r\nUser-Agent: ");
+        try writer.writeAll(headers.user_agent);
         if (headers.connection == .close) {
             try writer.writeAll("\r\nConnection: close");
         } else {
lib/std/net.zig
@@ -741,9 +741,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
     return Stream{ .handle = sockfd };
 }
 
-const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error {
+const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error{
     // TODO: break this up into error sets from the various underlying functions
-    
+
     TemporaryNameServerFailure,
     NameServerFailure,
     AddressFamilyNotSupported,
@@ -760,7 +760,7 @@ const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError ||
     Incomplete,
     InvalidIpv4Mapping,
     InvalidIPAddressFormat,
-    
+
     InterfaceNotFound,
     FileSystem,
 };
lib/std/Uri.zig
@@ -16,15 +16,27 @@ fragment: ?[]const u8,
 
 /// Applies URI encoding and replaces all reserved characters with their respective %XX code.
 pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+    // Only characters accepted by `isUnreserved` are copied through verbatim.
+    const escaped = try escapeStringWithFn(allocator, input, isUnreserved);
+    return escaped;
+}
+
+/// Percent-encodes `input` for use as a URI path, leaving every character
+/// accepted by `isPathChar` untouched.
+pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+    const escaped = try escapeStringWithFn(allocator, input, isPathChar);
+    return escaped;
+}
+
+/// Percent-encodes `input` for use as a URI query, leaving every character
+/// accepted by `isQueryChar` untouched.
+pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+    const escaped = try escapeStringWithFn(allocator, input, isQueryChar);
+    return escaped;
+}
+
+pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
     var outsize: usize = 0;
     for (input) |c| {
-        outsize += if (isUnreserved(c)) @as(usize, 1) else 3;
+        outsize += if (keepUnescaped(c)) @as(usize, 1) else 3;
     }
     var output = try allocator.alloc(u8, outsize);
     var outptr: usize = 0;
 
     for (input) |c| {
-        if (isUnreserved(c)) {
+        if (keepUnescaped(c)) {
             output[outptr] = c;
             outptr += 1;
         } else {
@@ -94,13 +106,14 @@ pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{Out
 
 pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort };
 
-/// Parses the URI or returns an error.
+/// Parses the URI or returns an error. This function is not compliant, but is required to parse
+/// some forms of URIs in the wild, such as HTTP Location headers.
 /// The return value will contain unescaped strings pointing into the
 /// original `text`. Each component that is provided, will be non-`null`.
-pub fn parse(text: []const u8) ParseError!Uri {
+pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
     var reader = SliceReader{ .slice = text };
     var uri = Uri{
-        .scheme = reader.readWhile(isSchemeChar),
+        .scheme = "",
         .user = null,
         .password = null,
         .host = null,
@@ -110,14 +123,6 @@ pub fn parse(text: []const u8) ParseError!Uri {
         .fragment = null,
     };
 
-    // after the scheme, a ':' must appear
-    if (reader.get()) |c| {
-        if (c != ':')
-            return error.UnexpectedCharacter;
-    } else {
-        return error.InvalidFormat;
-    }
-
     if (reader.peekPrefix("//")) { // authority part
         std.debug.assert(reader.get().? == '/');
         std.debug.assert(reader.get().? == '/');
@@ -179,6 +184,76 @@ pub fn parse(text: []const u8) ParseError!Uri {
     return uri;
 }
 
+/// Parses `text` as a URI that starts with a scheme, or returns an error.
+/// The return value will contain unescaped strings pointing into the
+/// original `text`. Each component that is provided, will be non-`null`.
+pub fn parse(text: []const u8) ParseError!Uri {
+    var reader = SliceReader{ .slice = text };
+    const scheme = reader.readWhile(isSchemeChar);
+
+    // The scheme has to be terminated by a ':'; running out of input before
+    // seeing one is reported as a format error.
+    const separator = reader.get() orelse return error.InvalidFormat;
+    if (separator != ':') return error.UnexpectedCharacter;
+
+    // Everything after the ':' is handled by the scheme-less parser.
+    var result = try parseWithoutScheme(reader.readUntilEof());
+    result.scheme = scheme;
+    return result;
+}
+
+/// Resolves the reference `R` against the base URI `Base`, following RFC 3986, Section 5.2.
+/// If `strict` is false, a reference scheme equal to the base scheme is ignored and the
+/// reference is resolved as if it carried no scheme.
+/// Rewritten paths are allocated with `arena`; every other component of the result
+/// aliases memory belonging to `Base` or `R`.
+pub fn resolve(Base: Uri, R: Uri, strict: bool, arena: std.mem.Allocator) !Uri {
+    var T: Uri = undefined;
+
+    const ignore_scheme = !strict and std.mem.eql(u8, R.scheme, Base.scheme);
+
+    if (R.scheme.len > 0 and !ignore_scheme) {
+        // The reference is absolute: only its dot segments need normalizing.
+        T.scheme = R.scheme;
+        T.user = R.user;
+        T.password = R.password;
+        T.host = R.host;
+        T.port = R.port;
+        T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
+        T.query = R.query;
+    } else {
+        if (R.host) |host| {
+            // Network-path reference: the authority comes from the reference.
+            T.user = R.user;
+            T.password = R.password;
+            T.host = host;
+            T.port = R.port;
+            T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
+            T.query = R.query;
+        } else {
+            if (R.path.len == 0) {
+                T.path = Base.path;
+                T.query = R.query orelse Base.query;
+            } else {
+                if (R.path[0] == '/') {
+                    T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
+                } else {
+                    // RFC 3986, Section 5.2.3 (merge): the reference replaces everything
+                    // after the last '/' of the base path, so only the base's directory
+                    // part (up to and including that '/') takes part in the resolution.
+                    const base_dir: []const u8 = if (std.mem.lastIndexOfScalar(u8, Base.path, '/')) |last_slash|
+                        Base.path[0 .. last_slash + 1]
+                    else
+                        "";
+                    T.path = try std.fs.path.resolvePosix(arena, &.{ "/", base_dir, R.path });
+                }
+                T.query = R.query;
+            }
+
+            // No authority in the reference: inherit all of it from the base.
+            T.user = Base.user;
+            T.password = Base.password;
+            T.host = Base.host;
+            T.port = Base.port;
+        }
+        T.scheme = Base.scheme;
+    }
+
+    // NOTE(review): `resolvePosix` is a filesystem path normalizer; confirm it matches
+    // RFC 3986 `remove_dot_segments` for URI paths (e.g. trailing slashes).
+    T.fragment = R.fragment;
+
+    return T;
+}
+
 const SliceReader = struct {
     const Self = @This();
 
@@ -284,6 +359,14 @@ fn isPathSeparator(c: u8) bool {
     };
 }
 
+/// Reports whether `c` is kept unescaped by `escapePath`: '/', ':', '@', or any
+/// character accepted by `isUnreserved` or `isSubLimit`.
+fn isPathChar(c: u8) bool {
+    return switch (c) {
+        '/', ':', '@' => true,
+        else => isUnreserved(c) or isSubLimit(c),
+    };
+}
+
+/// Reports whether `c` is kept unescaped by `escapeQuery`: '?' or any
+/// character accepted by `isPathChar`.
+fn isQueryChar(c: u8) bool {
+    if (c == '?') return true;
+    return isPathChar(c);
+}
+
 fn isQuerySeparator(c: u8) bool {
     return switch (c) {
         '#' => true,