Commit 5d40338f21

Nameless <truemedian@gmail.com>
2023-08-12 03:34:59
std.http: add Client.fetch and improve redirect logic
1 parent 49075d2
Changed files (3)
lib
test
standalone
lib/std/http/Client.zig
@@ -365,8 +365,11 @@ pub const Response = struct {
             if (trailing) continue;
 
             if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
-                if (res.content_length != null) return error.HttpHeadersInvalid;
-                res.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+                const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
+
+                if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid;
+
+                res.content_length = content_length;
             } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                 // Transfer-Encoding: second, first
                 // Transfer-Encoding: deflate, chunked
@@ -536,6 +539,8 @@ pub const Request = struct {
 
     /// Send the request to the server.
     pub fn start(req: *Request) StartError!void {
+        if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding;
+
         var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
         const w = buffered.writer();
 
@@ -607,7 +612,14 @@ pub const Request = struct {
             }
         }
 
-        try w.print("{}", .{req.headers});
+        for (req.headers.list.items) |entry| {
+            if (entry.value.len == 0) continue;
+
+            try w.writeAll(entry.name);
+            try w.writeAll(": ");
+            try w.writeAll(entry.value);
+            try w.writeAll("\r\n");
+        }
 
         try w.writeAll("\r\n");
 
@@ -635,13 +647,13 @@ pub const Request = struct {
         return index;
     }
 
-    pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, CannotRedirect, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
+    pub const WaitError = RequestError || StartError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, RedirectRequiresResend, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
 
     /// Waits for a response from the server and parses any headers that are sent.
     /// This function will block until the final response is received.
     ///
     /// If `handle_redirects` is true and the request has no payload, then this function will automatically follow
-    /// redirects. If a request payload is present, then this function will error with error.CannotRedirect.
+    /// redirects. If a request payload is present, then this function will error with error.RedirectRequiresResend.
     pub fn wait(req: *Request) WaitError!void {
         while (true) { // handle redirects
             while (true) { // read headers
@@ -697,9 +709,10 @@ pub const Request = struct {
                 req.response.parser.done = true;
             }
 
-            if (req.transfer_encoding == .none and req.response.status.class() == .redirect and req.handle_redirects) {
+            if (req.response.status.class() == .redirect and req.handle_redirects) {
                 req.response.skip = true;
 
+                // skip the body of the redirect response, this will at least leave the connection in a known good state.
                 const empty = @as([*]u8, undefined)[0..0];
                 assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary
 
@@ -715,6 +728,30 @@ pub const Request = struct {
                 const new_url = Uri.parse(location_duped) catch try Uri.parseWithoutScheme(location_duped);
                 const resolved_url = try req.uri.resolve(new_url, false, arena);
 
+                // is the redirect location on the same domain, or a subdomain of the original request?
+                const is_same_domain_or_subdomain = std.ascii.endsWithIgnoreCase(resolved_url.host.?, req.uri.host.?) and (resolved_url.host.?.len == req.uri.host.?.len or resolved_url.host.?[resolved_url.host.?.len - req.uri.host.?.len - 1] == '.');
+
+                if (resolved_url.host == null or !is_same_domain_or_subdomain or !std.ascii.eqlIgnoreCase(resolved_url.scheme, req.uri.scheme)) {
+                    // we're redirecting to a different domain, strip privileged headers like cookies
+                    _ = req.headers.delete("authorization");
+                    _ = req.headers.delete("www-authenticate");
+                    _ = req.headers.delete("cookie");
+                    _ = req.headers.delete("cookie2");
+                }
+
+                if (req.response.status == .see_other or ((req.response.status == .moved_permanently or req.response.status == .found) and req.method == .POST)) {
+                    // we're redirecting to a GET, so we need to change the method and remove the body
+                    req.method = .GET;
+                    req.transfer_encoding = .none;
+                    _ = req.headers.delete("transfer-encoding");
+                    _ = req.headers.delete("content-length");
+                    _ = req.headers.delete("content-type");
+                }
+
+                if (req.transfer_encoding != .none) {
+                    return error.RedirectRequiresResend; // The request body has already been sent. The request is still in a valid state, but the redirect must be handled manually.
+                }
+
                 try req.redirect(resolved_url);
 
                 try req.start();
@@ -735,9 +772,6 @@ pub const Request = struct {
                     };
                 }
 
-                if (req.response.status.class() == .redirect and req.handle_redirects and req.transfer_encoding != .none)
-                    return error.CannotRedirect; // The request body has already been sent. The request is still in a valid state, but the redirect must be handled manually.
-
                 break;
             }
         }
@@ -956,17 +990,17 @@ pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request
     UnsupportedTransferEncoding,
 };
 
-pub const Options = struct {
+pub const RequestOptions = struct {
     version: http.Version = .@"HTTP/1.1",
 
     handle_redirects: bool = true,
     max_redirects: u32 = 3,
-    header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
+    header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 },
 
     /// Must be an already acquired connection.
     connection: ?*ConnectionPool.Node = null,
 
-    pub const HeaderStrategy = union(enum) {
+    pub const StorageStrategy = union(enum) {
         /// In this case, the client's Allocator will be used to store the
         /// entire HTTP header. This value is the maximum total size of
         /// HTTP headers allowed, otherwise
@@ -988,8 +1022,12 @@ pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
 });
 
 /// Form and send a http request to a server.
+///
+/// `uri` must remain alive during the entire request.
+/// `headers` is cloned and may be freed after this function returns.
+///
 /// This function is threadsafe.
-pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: Options) RequestError!Request {
+pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Headers, options: RequestOptions) RequestError!Request {
     const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
 
     const port: u16 = uri.port orelse switch (protocol) {
@@ -1015,7 +1053,7 @@ pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Hea
         .uri = uri,
         .client = client,
         .connection = conn,
-        .headers = headers,
+        .headers = try headers.clone(client.allocator), // Headers must be cloned to properly handle header transformations in redirects.
         .method = method,
         .version = options.version,
         .redirects_left = options.max_redirects,
@@ -1039,6 +1077,123 @@ pub fn request(client: *Client, method: http.Method, uri: Uri, headers: http.Hea
     return req;
 }
 
+pub const FetchOptions = struct {
+    pub const Location = union(enum) {
+        url: []const u8,
+        uri: Uri,
+    };
+
+    pub const Payload = union(enum) {
+        string: []const u8,
+        file: std.fs.File,
+        none,
+    };
+
+    pub const ResponseStrategy = union(enum) {
+        storage: RequestOptions.StorageStrategy,
+        file: std.fs.File,
+        none,
+    };
+
+    header_strategy: RequestOptions.StorageStrategy = .{ .dynamic = 16 * 1024 },
+    response_strategy: ResponseStrategy = .{ .storage = .{ .dynamic = 16 * 1024 * 1024 } },
+
+    location: Location,
+    method: http.Method = .GET,
+    headers: http.Headers = http.Headers{ .allocator = std.heap.page_allocator, .owned = false },
+    payload: Payload = .none,
+};
+
+pub const FetchResult = struct {
+    status: http.Status,
+    body: ?[]const u8 = null,
+    headers: http.Headers,
+
+    allocator: Allocator,
+    options: FetchOptions,
+
+    pub fn deinit(res: *FetchResult) void {
+        if (res.options.response_strategy == .storage and res.options.response_strategy.storage == .dynamic) {
+            if (res.body) |body| res.allocator.free(body);
+        }
+
+        res.headers.deinit();
+    }
+};
+
+pub fn fetch(client: *Client, allocator: Allocator, options: FetchOptions) !FetchResult {
+    const has_transfer_encoding = options.headers.contains("transfer-encoding");
+    const has_content_length = options.headers.contains("content-length");
+
+    if (has_content_length or has_transfer_encoding) return error.UnsupportedHeader;
+
+    const uri = switch (options.location) {
+        .url => |u| try Uri.parse(u),
+        .uri => |u| u,
+    };
+
+    var req = try request(client, options.method, uri, options.headers, .{
+        .header_strategy = options.header_strategy,
+        .handle_redirects = options.payload == .none,
+    });
+    defer req.deinit();
+
+    { // Block to maintain lock of file to attempt to prevent a race condition where another process modifies the file while we are reading it.
+        // This relies on other processes actually obeying the advisory lock, which is not guaranteed.
+        if (options.payload == .file) try options.payload.file.lock(.shared);
+        defer if (options.payload == .file) options.payload.file.unlock();
+
+        switch (options.payload) {
+            .string => |str| req.transfer_encoding = .{ .content_length = str.len },
+            .file => |file| req.transfer_encoding = .{ .content_length = (try file.stat()).size },
+            .none => {},
+        }
+
+        try req.start();
+
+        switch (options.payload) {
+            .string => |str| try req.writeAll(str),
+            .file => |file| {
+                try file.seekTo(0);
+                var fifo = std.fifo.LinearFifo(u8, .{ .Static = 8192 }).init();
+                try fifo.pump(file.reader(), req.writer());
+            },
+            .none => {},
+        }
+
+        try req.finish();
+    }
+
+    try req.wait();
+
+    var res = FetchResult{
+        .status = req.response.status,
+        .headers = try req.response.headers.clone(allocator),
+
+        .allocator = allocator,
+        .options = options,
+    };
+
+    switch (options.response_strategy) {
+        .storage => |storage| switch (storage) {
+            .dynamic => |max| res.body = try req.reader().readAllAlloc(allocator, max),
+            .static => |buf| res.body = buf[0..try req.reader().readAll(buf)],
+        },
+        .file => |file| {
+            var fifo = std.fifo.LinearFifo(u8, .{ .Static = 8192 }).init();
+            try fifo.pump(req.reader(), file.writer());
+        },
+        .none => { // Take advantage of request internals to discard the response body and make the connection available for another request.
+            req.response.skip = true;
+
+            const empty = @as([*]u8, undefined)[0..0];
+            assert(try req.transferRead(empty) == 0); // we're skipping, no buffer is necessary
+        },
+    }
+
+    return res;
+}
+
 test {
     const builtin = @import("builtin");
     const native_endian = comptime builtin.cpu.arch.endian();
lib/std/http/Headers.zig
@@ -57,6 +57,18 @@ pub const Headers = struct {
         return .{ .allocator = allocator };
     }
 
+    pub fn initList(allocator: Allocator, list: []const Field) Headers {
+        var new = Headers.init(allocator);
+
+        try new.list.ensureTotalCapacity(allocator, list.len);
+        try new.index.ensureTotalCapacity(allocator, list.len);
+        for (list) |field| {
+            try new.append(field.name, field.value);
+        }
+
+        return new;
+    }
+
     pub fn deinit(headers: *Headers) void {
         headers.deallocateIndexListsAndFields();
         headers.index.deinit(headers.allocator);
@@ -78,7 +90,7 @@ pub const Headers = struct {
             entry.name = kv.key_ptr.*;
             try kv.value_ptr.append(headers.allocator, n);
         } else {
-            const name_duped = if (headers.owned) try headers.allocator.dupe(u8, name) else name;
+            const name_duped = if (headers.owned) try std.ascii.allocLowerString(headers.allocator, name) else name;
             errdefer if (headers.owned) headers.allocator.free(name_duped);
 
             entry.name = name_duped;
@@ -97,6 +109,7 @@ pub const Headers = struct {
         return headers.index.contains(name);
     }
 
+    /// Removes all headers with the given name.
     pub fn delete(headers: *Headers, name: []const u8) bool {
         if (headers.index.fetchRemove(name)) |kv| {
             var index = kv.value;
@@ -268,6 +281,18 @@ pub const Headers = struct {
         headers.index.clearRetainingCapacity();
         headers.list.clearRetainingCapacity();
     }
+
+    pub fn clone(headers: Headers, allocator: Allocator) !Headers {
+        var new = Headers.init(allocator);
+
+        try new.list.ensureTotalCapacity(allocator, headers.list.capacity);
+        try new.index.ensureTotalCapacity(allocator, headers.index.capacity());
+        for (headers.list.items) |field| {
+            try new.append(field.name, field.value);
+        }
+
+        return new;
+    }
 };
 
 test "Headers.append" {
test/standalone/http.zig
@@ -571,7 +571,28 @@ pub fn main() !void {
     // connection has been kept alive
     try testing.expect(client.connection_pool.free_len == 1);
 
-    { // issue 16282
+    { // Client.fetch()
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        try h.append("content-type", "text/plain");
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
+        defer calloc.free(location);
+
+        log.info("{s}", .{location});
+        var res = try client.fetch(calloc, .{
+            .location = .{ .url = location },
+            .method = .POST,
+            .headers = h,
+            .payload = .{ .string = "Hello, World!\n" },
+        });
+        defer res.deinit();
+
+        try testing.expectEqualStrings("Hello, World!\n", res.body.?);
+    }
+
+    { // issue 16282 *** This test leaves the client in an invalid state, it must be last ***
         const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
         defer calloc.free(location);
         const uri = try std.Uri.parse(location);