Commit 7285eedcd2

Nameless <truemedian@gmail.com>
2023-04-24 18:24:51
std.http: do -> wait, fix redirects
1 parent 1310129
Changed files (3)
lib/std/http/Client.zig
@@ -365,7 +365,7 @@ pub const Response = struct {
         CompressionNotSupported,
     };
 
-    pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
+    pub fn parse(res: *Response, bytes: []const u8, trailing: bool) ParseError!void {
         var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
 
         const first_line = it.next() orelse return error.HttpHeadersInvalid;
@@ -398,6 +398,8 @@ pub const Response = struct {
 
             try res.headers.append(header_name, header_value);
 
+            if (trailing) continue;
+
             if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                 if (res.content_length != null) return error.HttpHeadersInvalid;
                 res.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
@@ -480,7 +482,7 @@ pub const Response = struct {
 
 /// A HTTP request that has been sent.
 ///
-/// Order of operations: request[ -> write -> finish] -> do -> read
+/// Order of operations: request -> start[ -> write -> finish] -> wait -> read
 pub const Request = struct {
     uri: Uri,
     client: *Client,
@@ -508,8 +510,9 @@ pub const Request = struct {
             .zstd => |*zstd| zstd.deinit(),
         }
 
+        req.response.headers.deinit();
+
         if (req.response.parser.header_bytes_owned) {
-            req.response.headers.deinit();
             req.response.parser.header_bytes.deinit(req.client.allocator);
         }
 
@@ -524,6 +527,44 @@ pub const Request = struct {
         req.* = undefined;
     }
 
+    // Re-targets this request at `uri`: deinitializes the response decompressor, releases the current connection back to the client's pool, connects to the new host/port and resets the response while keeping its header storage and parser. Every resource associated with the request must either be deallocated here or deliberately kept for reuse.
+    // This needs to be kept in sync with `deinit` and `request`. NOTE(review): the old connection is released before the fallible steps below, and `req.connection` is only reassigned by the `connect` call, so on an error return it still names the released connection — confirm that `deinit` after a failed redirect is safe.
+    fn redirect(req: *Request, uri: Uri) !void {
+        assert(req.response.parser.done); // the previous response's parser must already be marked done
+
+        switch (req.response.compression) { // deinitialize whichever decompressor variant is active
+            .none => {},
+            .deflate => |*deflate| deflate.deinit(),
+            .gzip => |*gzip| gzip.deinit(),
+            .zstd => |*zstd| zstd.deinit(),
+        }
+
+        req.client.connection_pool.release(req.client, req.connection); // give the old connection back to the client's pool
+
+        const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme; // schemes missing from protocol_map are rejected
+
+        const port: u16 = uri.port orelse switch (protocol) { // no explicit port: use the default for the protocol
+            .plain => 80,
+            .tls => 443,
+        };
+
+        const host = uri.host orelse return error.UriMissingHost;
+
+        req.uri = uri;
+        req.connection = try req.client.connect(host, port, protocol);
+        req.redirects_left -= 1; // NOTE(review): presumably the caller has verified redirects_left > 0 before calling — confirm
+        req.response.headers.clearRetainingCapacity(); // drop the old response headers but keep the allocation
+        req.response.parser.reset();
+
+        req.response = .{ // rebuild the response: status/reason/version become undefined, headers and parser are carried over
+            .status = undefined,
+            .reason = undefined,
+            .version = undefined,
+            .headers = req.response.headers,
+            .parser = req.response.parser,
+        };
+    }
+
     pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
 
     /// Send the request to the server.
@@ -627,14 +668,14 @@ pub const Request = struct {
         return index;
     }
 
-    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
+    pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, CannotRedirect, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
 
     /// Waits for a response from the server and parses any headers that are sent.
     /// This function will block until the final response is received.
     ///
-    /// If `handle_redirects` is true, then this function will automatically follow
-    /// redirects.
-    pub fn do(req: *Request) DoError!void {
+    /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow
+    /// redirects. If a request payload is present, then this function will error with error.CannotRedirect.
+    pub fn wait(req: *Request) WaitError!void {
         while (true) { // handle redirects
             while (true) { // read headers
                 try req.connection.data.buffered.fill();
@@ -645,7 +686,7 @@ pub const Request = struct {
                 if (req.response.parser.state.isContent()) break;
             }
 
-            try req.response.parse(req.response.parser.header_bytes.items);
+            try req.response.parse(req.response.parser.header_bytes.items, false);
 
             if (req.response.status == .switching_protocols) {
                 req.connection.data.closing = false;
@@ -684,7 +725,7 @@ pub const Request = struct {
                 req.response.parser.done = true;
             }
 
-            if (req.response.status.class() == .redirect and req.handle_redirects) {
+            if (req.transfer_encoding == .none and req.response.status.class() == .redirect and req.handle_redirects) {
                 req.response.skip = true;
 
                 const empty = @as([*]u8, undefined)[0..0];
@@ -694,26 +735,17 @@ pub const Request = struct {
 
                 const location = req.response.headers.getFirstValue("location") orelse
                     return error.HttpRedirectMissingLocation;
-                const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
-
-                var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
-                const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
-                errdefer new_arena.deinit();
-
-                req.arena.deinit();
-                req.arena = new_arena;
-
-                const new_req = try req.client.request(req.method, resolved_url, req.headers, .{
-                    .version = req.version,
-                    .max_redirects = req.redirects_left - 1,
-                    .header_strategy = if (req.response.parser.header_bytes_owned) .{
-                        .dynamic = req.response.parser.max_header_bytes,
-                    } else .{
-                        .static = req.response.parser.header_bytes.items.ptr[0..req.response.parser.max_header_bytes],
-                    },
-                });
-                req.deinit();
-                req.* = new_req;
+
+                const arena = req.arena.allocator();
+
+                const location_duped = try arena.dupe(u8, location);
+
+                const new_url = Uri.parse(location_duped) catch try Uri.parseWithoutScheme(location_duped);
+                const resolved_url = try req.uri.resolve(new_url, false, arena);
+
+                try req.redirect(resolved_url);
+
+                try req.start();
             } else {
                 req.response.skip = false;
                 if (!req.response.parser.done) {
@@ -731,6 +763,9 @@ pub const Request = struct {
                     };
                 }
 
+                if (req.response.status.class() == .redirect and req.handle_redirects and req.transfer_encoding != .none)
+                    return error.CannotRedirect; // The request body has already been sent. The request is still in a valid state, but the redirect must be handled manually.
+
                 break;
             }
         }
@@ -768,7 +803,7 @@ pub const Request = struct {
 
                 // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error.
                 // This will *only* fail for a malformed trailer.
-                req.response.parse(req.response.parser.header_bytes.items) catch return error.InvalidTrailers;
+                req.response.parse(req.response.parser.header_bytes.items, true) catch return error.InvalidTrailers;
             }
         }
 
lib/std/http/test.zig
@@ -62,7 +62,7 @@ test "client requests server" {
     try client_req.writeAll("Hello, ");
     try client_req.writeAll("World!\n");
     try client_req.finish();
-    try client_req.do(); // this waits for a response
+    try client_req.wait(); // this waits for a response
 
     const body = try client_req.reader().readAllAlloc(allocator, 8192 * 1024);
     defer allocator.free(body);
src/Package.zig
@@ -486,8 +486,7 @@ fn fetchAndUnpack(
         defer req.deinit();
 
         try req.start();
-
-        try req.do();
+        try req.wait();
 
         if (mem.endsWith(u8, uri.path, ".tar.gz")) {
             // I observed the gzip stream to read 1 byte at a time, so I am using a