Commit 651aa5e8e4

Andrew Kelley <andrew@ziglang.org>
2024-02-17 10:21:18
std.http.Client: eliminate arena allocator usage
Before, this code constructed an arena allocator and then used it when handling redirects. You know what's better than having threads fight over an allocator? Avoiding dynamic memory allocation in the first place. This commit reuses the http headers static buffer for handling redirects. The new location is copied to the beginning of the static header buffer and then the subsequent request uses a subslice of that buffer.
1 parent 107992d
Changed files (2)
lib
lib/std/http/Client.zig
@@ -597,9 +597,6 @@ pub const Request = struct {
     /// This field is undefined until `wait` is called.
     response: Response,
 
-    /// Used as a allocator for resolving redirects locations.
-    arena: std.heap.ArenaAllocator,
-
     /// Standard headers that have default, but overridable, behavior.
     headers: Headers,
 
@@ -661,8 +658,6 @@ pub const Request = struct {
             }
             req.client.connection_pool.release(req.client.allocator, connection);
         }
-
-        req.arena.deinit();
         req.* = undefined;
     }
 
@@ -842,11 +837,12 @@ pub const Request = struct {
     }
 
     pub const WaitError = RequestError || SendError || TransferReadError ||
-        proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError ||
+        proto.HeadersParser.CheckCompleteHeadError || Response.ParseError ||
         error{ // TODO: file zig fmt issue for this bad indentation
         TooManyHttpRedirects,
         RedirectRequiresResend,
-        HttpRedirectMissingLocation,
+        HttpRedirectLocationMissing,
+        HttpRedirectLocationInvalid,
         CompressionInitializationFailed,
         CompressionUnsupported,
     };
@@ -927,31 +923,40 @@ pub const Request = struct {
             }
 
             if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
-                req.response.skip = true;
-
                 // skip the body of the redirect response, this will at least
                 // leave the connection in a known good state.
+                req.response.skip = true;
                 assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary
 
                 if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects;
 
                 const location = req.response.location orelse
-                    return error.HttpRedirectMissingLocation;
-
-                const arena = req.arena.allocator();
-
-                const location_duped = try arena.dupe(u8, location);
-
-                const new_url = Uri.parse(location_duped) catch try Uri.parseWithoutScheme(location_duped);
-                const resolved_url = try req.uri.resolve(new_url, false, arena);
+                    return error.HttpRedirectLocationMissing;
+
+                // This mutates the beginning of header_buffer and uses that
+                // for the backing memory of the returned new_uri.
+                const header_buffer = req.response.parser.header_bytes_buffer;
+                const new_uri = req.uri.resolve_inplace(location, header_buffer) catch
+                    return error.HttpRedirectLocationInvalid;
+
+                // The new URI references the beginning of header_bytes_buffer memory.
+                // That memory will be kept, but everything after it will be
+                // reused by the subsequent request. In other words,
+                // header_bytes_buffer must be large enough to store all
+                // redirect locations as well as the final request header.
+                const path_end = new_uri.path.ptr + new_uri.path.len;
+                // https://github.com/ziglang/zig/issues/1738
+                const path_offset = @intFromPtr(path_end) - @intFromPtr(header_buffer.ptr);
+                const end_offset = @max(path_offset, location.len);
+                req.response.parser.header_bytes_buffer = header_buffer[end_offset..];
 
                 const is_same_domain_or_subdomain =
-                    std.ascii.endsWithIgnoreCase(resolved_url.host.?, req.uri.host.?) and
-                    (resolved_url.host.?.len == req.uri.host.?.len or
-                    resolved_url.host.?[resolved_url.host.?.len - req.uri.host.?.len - 1] == '.');
+                    std.ascii.endsWithIgnoreCase(new_uri.host.?, req.uri.host.?) and
+                    (new_uri.host.?.len == req.uri.host.?.len or
+                    new_uri.host.?[new_uri.host.?.len - req.uri.host.?.len - 1] == '.');
 
-                if (resolved_url.host == null or !is_same_domain_or_subdomain or
-                    !std.ascii.eqlIgnoreCase(resolved_url.scheme, req.uri.scheme))
+                if (new_uri.host == null or !is_same_domain_or_subdomain or
+                    !std.ascii.eqlIgnoreCase(new_uri.scheme, req.uri.scheme))
                 {
                     // When redirecting to a different domain, strip privileged headers.
                     req.privileged_headers = &.{};
@@ -975,7 +980,7 @@ pub const Request = struct {
                     return error.RedirectRequiresResend;
                 }
 
-                try req.redirect(resolved_url);
+                try req.redirect(new_uri);
                 try req.send(.{});
             } else {
                 req.response.skip = false;
@@ -1341,7 +1346,7 @@ pub fn connectTunnel(
             client.connection_pool.release(client.allocator, conn);
         }
 
-        const uri = Uri{
+        const uri: Uri = .{
             .scheme = "http",
             .user = null,
             .password = null,
@@ -1548,15 +1553,12 @@ pub fn open(
             .version = undefined,
             .parser = proto.HeadersParser.init(options.server_header_buffer),
         },
-        .arena = undefined,
         .headers = options.headers,
         .extra_headers = options.extra_headers,
         .privileged_headers = options.privileged_headers,
     };
     errdefer req.deinit();
 
-    req.arena = std.heap.ArenaAllocator.init(client.allocator);
-
     return req;
 }
 
lib/std/Uri.zig
@@ -342,7 +342,7 @@ pub fn format(
 /// The return value will contain unescaped strings pointing into the
 /// original `text`. Each component that is provided, will be non-`null`.
 pub fn parse(text: []const u8) ParseError!Uri {
-    var reader = SliceReader{ .slice = text };
+    var reader: SliceReader = .{ .slice = text };
     const scheme = reader.readWhile(isSchemeChar);
 
     // after the scheme, a ':' must appear
@@ -359,111 +359,145 @@ pub fn parse(text: []const u8) ParseError!Uri {
     return uri;
 }
 
-/// Implementation of RFC 3986, Section 5.2.4. Removes dot segments from a URI path.
-///
-/// `std.fs.path.resolvePosix` is not sufficient here because it may return relative paths and does not preserve trailing slashes.
-fn removeDotSegments(allocator: Allocator, paths: []const []const u8) Allocator.Error![]const u8 {
-    var result = std.ArrayList(u8).init(allocator);
-    defer result.deinit();
-
-    for (paths) |p| {
-        var it = std.mem.tokenizeScalar(u8, p, '/');
-        while (it.next()) |component| {
-            if (std.mem.eql(u8, component, ".")) {
-                continue;
-            } else if (std.mem.eql(u8, component, "..")) {
-                if (result.items.len == 0)
-                    continue;
+pub const ResolveInplaceError = ParseError || error{OutOfMemory}; // OutOfMemory: merge_paths found its fixed scratch buffer too small (no allocator is involved)
 
-                while (true) {
-                    const ends_with_slash = result.items[result.items.len - 1] == '/';
-                    result.items.len -= 1;
-                    if (ends_with_slash or result.items.len == 0) break;
-                }
-            } else {
-                try result.ensureUnusedCapacity(1 + component.len);
-                result.appendAssumeCapacity('/');
-                result.appendSliceAssumeCapacity(component);
-            }
-        }
-    }
+/// Resolves the URI reference `new` against `base` per RFC 3986, Section 5.
+/// `new` is copied to the front of `aux_buf` (they may overlap; `aux_buf` must
+/// hold at least `new.len` bytes), parsed there, and its path is cleaned up in
+/// place. A relative path is merged with `base.path` in the space right after
+/// the copy; `error.OutOfMemory` means that space was too small.
+pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: []u8) ResolveInplaceError!Uri {
+    if (@intFromPtr(aux_buf.ptr) <= @intFromPtr(new.ptr)) std.mem.copyForwards(u8, aux_buf, new) else std.mem.copyBackwards(u8, aux_buf, new); // pick the overlap-safe copy direction
+    // From here on the bytes behind `new` may have been overwritten; only `new.len` is still used.
+    const new_mut = aux_buf[0..new.len];
+
+    const new_parsed, const has_scheme = p: { // has_scheme records which parser accepted the text
+        break :p .{
+            parse(new_mut) catch |first_err| {
+                break :p .{
+                    parseWithoutScheme(new_mut) catch return first_err, // both parsers failed: report the first error
+                    false,
+                };
+            },
+            true,
+        };
+    };
 
-    // ensure a trailing slash is kept
-    const last_path = paths[paths.len - 1];
-    if (last_path.len > 0 and last_path[last_path.len - 1] == '/') {
-        try result.append('/');
-    }
+    // The parsed path is a sub-slice of the mutable `new_mut`, so dropping const here is sound.
+    const new_path: []u8 = @constCast(new_parsed.path);
+
+    if (has_scheme) return .{ // reference carries its own scheme: nothing is taken from `base`
+        .scheme = new_parsed.scheme,
+        .user = new_parsed.user,
+        .host = new_parsed.host,
+        .port = new_parsed.port,
+        .path = remove_dot_segments(new_path),
+        .query = new_parsed.query,
+        .fragment = new_parsed.fragment,
+    };
 
-    return result.toOwnedSlice();
-}
+    if (new_parsed.host) |host| return .{ // reference names a host: only the scheme comes from `base`
+        .scheme = base.scheme,
+        .user = new_parsed.user,
+        .host = host,
+        .port = new_parsed.port,
+        .path = remove_dot_segments(new_path),
+        .query = new_parsed.query,
+        .fragment = new_parsed.fragment,
+    };
 
-/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5.
-///
-/// Assumes `arena` owns all memory in `base` and `ref`. `arena` will own all memory in the returned URI.
-pub fn resolve(base: Uri, ref: Uri, strict: bool, arena: Allocator) Allocator.Error!Uri {
-    var target: Uri = Uri{
-        .scheme = "",
-        .user = null,
-        .password = null,
-        .host = null,
-        .port = null,
-        .path = "",
-        .query = null,
-        .fragment = null,
+    const path, const query = b: {
+        if (new_path.len == 0) // empty path: keep the base path, and the base query unless `new` has one
+            break :b .{
+                base.path,
+                new_parsed.query orelse base.query,
+            };
+
+        if (new_path[0] == '/') // absolute path: use it as given, minus dot segments
+            break :b .{
+                remove_dot_segments(new_path),
+                new_parsed.query,
+            };
+
+        break :b .{ // relative path: merge with the base path in the unused tail of `aux_buf`
+            try merge_paths(base.path, new_path, aux_buf[new_mut.len..]),
+            new_parsed.query,
+        };
     };
 
-    if (ref.scheme.len > 0 and (strict or !std.mem.eql(u8, ref.scheme, base.scheme))) {
-        target.scheme = ref.scheme;
-        target.user = ref.user;
-        target.host = ref.host;
-        target.port = ref.port;
-        target.path = try removeDotSegments(arena, &.{ref.path});
-        target.query = ref.query;
-    } else {
-        target.scheme = base.scheme;
-        if (ref.host) |host| {
-            target.user = ref.user;
-            target.host = host;
-            target.port = ref.port;
-            target.path = ref.path;
-            target.path = try removeDotSegments(arena, &.{ref.path});
-            target.query = ref.query;
+    return .{ // scheme and authority from `base`; the fragment always comes from `new`
+        .scheme = base.scheme,
+        .user = base.user,
+        .host = base.host,
+        .port = base.port,
+        .path = path,
+        .query = query,
+        .fragment = new_parsed.fragment,
+    };
+}
+
+/// In-place implementation of RFC 3986, Section 5.2.4: compacts `path` toward its front and returns the leading sub-slice holding the result.
+fn remove_dot_segments(path: []u8) []u8 {
+    var in_i: usize = 0; // read cursor: path[in_i..] is the input still to be processed
+    var out_i: usize = 0; // write cursor: path[0..out_i] is the output built so far
+    while (in_i < path.len) {
+        if (std.mem.startsWith(u8, path[in_i..], "./")) { // remaining input begins with "./": drop it
+            in_i += 2;
+        } else if (std.mem.startsWith(u8, path[in_i..], "../")) { // remaining input begins with "../": drop it
+            in_i += 3;
+        } else if (std.mem.startsWith(u8, path[in_i..], "/./")) { // skip "/." and leave the following '/' to start the next segment
+            in_i += 2;
+        } else if (std.mem.eql(u8, path[in_i..], "/.")) { // remaining input is exactly "/.": it becomes "/"
+            in_i += 1;
+            path[in_i] = '/'; // overwrite the '.' so the next iteration copies a lone '/'
+        } else if (std.mem.startsWith(u8, path[in_i..], "/../")) { // skip "/.." and drop the last output segment
+            in_i += 3;
+            while (out_i > 0) { // back up to the most recent '/' in the output, or to the start
+                out_i -= 1;
+                if (path[out_i] == '/') break;
+            }
+        } else if (std.mem.eql(u8, path[in_i..], "/..")) { // remaining input is exactly "/..": as above, but leave a trailing '/'
+            in_i += 2;
+            path[in_i] = '/'; // overwrite the second '.' so the next iteration copies a lone '/'
+            while (out_i > 0) {
+                out_i -= 1;
+                if (path[out_i] == '/') break;
+            }
+        } else if (std.mem.eql(u8, path[in_i..], ".")) { // remaining input is exactly ".": drop it
+            in_i += 1;
+        } else if (std.mem.eql(u8, path[in_i..], "..")) { // remaining input is exactly "..": drop it
+            in_i += 2;
         } else {
-            if (ref.path.len == 0) {
-                target.path = base.path;
-                target.query = ref.query orelse base.query;
-            } else {
-                if (ref.path[0] == '/') {
-                    target.path = try removeDotSegments(arena, &.{ref.path});
-                } else {
-                    target.path = try removeDotSegments(arena, &.{ std.fs.path.dirnamePosix(base.path) orelse "", ref.path });
-                }
-                target.query = ref.query;
+            while (true) { // ordinary segment: copy it (with its leading '/', if any) up to the next '/' or the end
+                path[out_i] = path[in_i];
+                out_i += 1;
+                in_i += 1;
+                if (in_i >= path.len or path[in_i] == '/') break;
             }
-
-            target.user = base.user;
-            target.host = base.host;
-            target.port = base.port;
         }
     }
-
-    target.fragment = ref.fragment;
-
-    return target;
+    return path[0..out_i]; // the result reuses the front of the caller's buffer
 }
 
-test resolve {
-    const base = try parse("http://a/b/c/d;p?q");
-
-    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
-    defer arena.deinit();
+test remove_dot_segments {
+    {
+        var buffer = "/a/b/c/./../../g".*; // mutable array copy of the literal, since the function rewrites its argument in place
+        try std.testing.expectEqualStrings("/a/g", remove_dot_segments(&buffer)); // "." is dropped; each ".." removes the segment before it ("c", then "b")
+    }
+}
 
-    try std.testing.expectEqualDeep(try parse("http://a/b/c/blog/"), try base.resolve(try parseWithoutScheme("blog/"), true, arena.allocator()));
-    try std.testing.expectEqualDeep(try parse("http://a/b/c/blog/?k"), try base.resolve(try parseWithoutScheme("blog/?k"), true, arena.allocator()));
-    try std.testing.expectEqualDeep(try parse("http://a/b/blog/"), try base.resolve(try parseWithoutScheme("../blog/"), true, arena.allocator()));
-    try std.testing.expectEqualDeep(try parse("http://a/b/blog"), try base.resolve(try parseWithoutScheme("../blog"), true, arena.allocator()));
-    try std.testing.expectEqualDeep(try parse("http://e"), try base.resolve(try parseWithoutScheme("//e"), true, arena.allocator()));
-    try std.testing.expectEqualDeep(try parse("https://a:1/"), try base.resolve(try parse("https://a:1/"), true, arena.allocator()));
+/// RFC 3986, Section 5.2.3 (Merge Paths): joins `base` up to and including its last '/' with `new`, then removes dot segments from the result.
+fn merge_paths(base: []const u8, new: []u8, aux: []u8) error{OutOfMemory}![]u8 {
+    if (aux.len < base.len + 1 + new.len) return error.OutOfMemory; // conservative bound: room for all of base, one '/', and all of new
+    if (base.len == 0) { // empty base path: the merged path is "/" followed by new
+        aux[0] = '/';
+        @memcpy(aux[1..][0..new.len], new);
+        return remove_dot_segments(aux[0 .. new.len + 1]);
+    }
+    const pos = std.mem.lastIndexOfScalar(u8, base, '/') orelse return remove_dot_segments(new); // base has no '/': nothing is kept from it, new is cleaned up in place and aux stays unused
+    @memcpy(aux[0 .. pos + 1], base[0 .. pos + 1]); // keep base through its last '/'
+    @memcpy(aux[pos + 1 ..][0..new.len], new); // append new right after that '/'
+    return remove_dot_segments(aux[0 .. pos + 1 + new.len]);
 }
 
 const SliceReader = struct {