Commit a23c8662b4

Nameless <truemedian@gmail.com>
2023-04-18 02:37:24
std.http: pass Method to request directly, parse trailing headers
1 parent e65cbff
Changed files (3)
lib/std/http/Client.zig
@@ -527,7 +527,7 @@ pub const Request = struct {
     pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
 
     /// Send the request to the server.
-    pub fn start(req: *Request, uri: Uri) StartError!void {
+    pub fn start(req: *Request) StartError!void {
         var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
         const w = buffered.writer();
 
@@ -535,14 +535,14 @@ pub const Request = struct {
         try w.writeByte(' ');
 
         if (req.method == .CONNECT) {
-            try w.writeAll(uri.host.?);
+            try w.writeAll(req.uri.host.?);
             try w.writeByte(':');
-            try w.print("{}", .{uri.port.?});
+            try w.print("{}", .{req.uri.port.?});
         } else if (req.connection.data.proxied) {
             // proxied connections require the full uri
-            try w.print("{+/}", .{uri});
+            try w.print("{+/}", .{req.uri});
         } else {
-            try w.print("{/}", .{uri});
+            try w.print("{/}", .{req.uri});
         }
 
         try w.writeByte(' ');
@@ -551,7 +551,7 @@ pub const Request = struct {
 
         if (!req.headers.contains("host")) {
             try w.writeAll("Host: ");
-            try w.writeAll(uri.host.?);
+            try w.writeAll(req.uri.host.?);
             try w.writeAll("\r\n");
         }
 
@@ -704,8 +704,7 @@ pub const Request = struct {
                 req.arena.deinit();
                 req.arena = new_arena;
 
-                const new_req = try req.client.request(resolved_url, req.headers, .{
-                    .method = req.method,
+                const new_req = try req.client.request(req.method, resolved_url, req.headers, .{
                     .version = req.version,
                     .max_redirects = req.redirects_left - 1,
                     .header_strategy = if (req.response.parser.header_bytes_owned) .{
@@ -738,7 +737,7 @@ pub const Request = struct {
         }
     }
 
-    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
+    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{ DecompressionFailure, InvalidTrailers };
 
     pub const Reader = std.io.Reader(*Request, ReadError, read);
 
@@ -756,12 +755,22 @@ pub const Request = struct {
         };
 
         if (out_index == 0) {
+            const has_trail = !req.response.parser.state.isContent();
+
             while (!req.response.parser.state.isContent()) { // read trailing headers
                 try req.connection.data.buffered.fill();
 
                 const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
                 req.connection.data.buffered.clear(@intCast(u16, nchecked));
             }
+
+            if (has_trail) {
+                req.response.headers = http.Headers{ .allocator = req.client.allocator, .owned = false };
+
+                // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error.
+                // This will *only* fail for a malformed trailer.
+                req.response.parse(req.response.parser.header_bytes.items) catch return error.InvalidTrailers;
+            }
         }
 
         return out_index;
@@ -943,7 +952,6 @@ pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request
 };
 
 pub const Options = struct {
-    method: http.Method = .GET,
     version: http.Version = .@"HTTP/1.1",
 
     handle_redirects: bool = true,
@@ -976,7 +984,7 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
 
 /// Form and send a http request to a server.
 /// This function is threadsafe.
-pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Options) RequestError!Request {
+pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: Options) RequestError!Request {
     const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
 
     const port: u16 = uri.port orelse switch (protocol) {
@@ -1003,7 +1011,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
         .client = client,
         .connection = conn,
         .headers = headers,
-        .method = options.method,
+        .method = method,
         .version = options.version,
         .redirects_left = options.max_redirects,
         .handle_redirects = options.handle_redirects,
lib/std/http/Server.zig
@@ -518,12 +518,22 @@ pub const Response = struct {
         };
 
         if (out_index == 0) {
+            const has_trail = !res.request.parser.state.isContent();
+
             while (!res.request.parser.state.isContent()) { // read trailing headers
                 try res.connection.fill();
 
                 const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
                 res.connection.clear(@intCast(u16, nchecked));
             }
+
+            if (has_trail) {
+                res.request.headers = http.Headers{ .allocator = res.server.allocator, .owned = false };
+
+                // The response headers before the trailers are already guaranteed to be valid, so they will always be parsed again and cannot return an error.
+                // This will *only* fail for a malformed trailer.
+                res.request.parse(res.request.parser.header_bytes.items) catch return error.InvalidTrailers;
+            }
         }
 
         return out_index;
src/Package.zig
@@ -482,7 +482,7 @@ fn fetchAndUnpack(
         var h = std.http.Headers{ .allocator = gpa };
         defer h.deinit();
 
-        var req = try http_client.request(uri, h, .{ .method = .GET });
+        var req = try http_client.request(.GET, uri, h, .{});
         defer req.deinit();
 
         try req.start();