Commit 2cdc0a8b50

Andrew Kelley <andrew@ziglang.org>
2023-01-04 00:03:28
std.http.Client: do not heap allocate for requests
1 parent ed23615
Changed files (1)
lib
std
lib/std/http/Client.zig
@@ -9,9 +9,8 @@ const net = std.net;
 const Client = @This();
 const Url = std.Url;
 
+/// TODO: remove this field (currently required due to tcpConnectToHost)
 allocator: std.mem.Allocator,
-headers: std.ArrayListUnmanaged(u8) = .{},
-active_requests: usize = 0,
 ca_bundle: std.crypto.Certificate.Bundle = .{},
 
 /// TODO: emit error.UnexpectedEndOfStream or something like that when the read
@@ -20,44 +19,23 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
 pub const Request = struct {
     client: *Client,
     stream: net.Stream,
-    headers: std.ArrayListUnmanaged(u8) = .{},
     tls_client: std.crypto.tls.Client,
     protocol: Protocol,
     response_headers: http.Headers = .{},
 
-    pub const Protocol = enum { http, https };
-
-    pub const Options = struct {
+    pub const Headers = struct {
         method: http.Method = .GET,
-    };
+        connection: Connection,
 
-    pub fn deinit(req: *Request) void {
-        req.client.active_requests -= 1;
-        req.headers.deinit(req.client.allocator);
-        req.* = undefined;
-    }
+        pub const Connection = enum {
+            close,
+            @"keep-alive",
+        };
+    };
 
-    pub fn addHeader(req: *Request, name: []const u8, value: []const u8) !void {
-        const gpa = req.client.allocator;
-        // Ensure an extra +2 for the \r\n in end()
-        try req.headers.ensureUnusedCapacity(gpa, name.len + value.len + 6);
-        req.headers.appendSliceAssumeCapacity(name);
-        req.headers.appendSliceAssumeCapacity(": ");
-        req.headers.appendSliceAssumeCapacity(value);
-        req.headers.appendSliceAssumeCapacity("\r\n");
-    }
+    pub const Protocol = enum { http, https };
 
-    pub fn end(req: *Request) !void {
-        req.headers.appendSliceAssumeCapacity("\r\n");
-        switch (req.protocol) {
-            .http => {
-                try req.stream.writeAll(req.headers.items);
-            },
-            .https => {
-                try req.tls_client.writeAll(req.stream, req.headers.items);
-            },
-        }
-    }
+    pub const Options = struct {};
 
     pub fn readAll(req: *Request, buffer: []u8) !usize {
         return readAtLeast(req, buffer, buffer.len);
@@ -113,13 +91,14 @@ pub const Request = struct {
     }
 };
 
-pub fn deinit(client: *Client) void {
-    assert(client.active_requests == 0);
-    client.headers.deinit(client.allocator);
+pub fn deinit(client: *Client, gpa: std.mem.Allocator) void {
+    client.ca_bundle.deinit(gpa);
     client.* = undefined;
 }
 
-pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
+pub fn request(client: *Client, url: Url, headers: Request.Headers, options: Request.Options) !Request {
+    _ = options; // we have no options yet
+
     const protocol = std.meta.stringToEnum(Request.Protocol, url.scheme) orelse
         return error.UnsupportedUrlScheme;
     const port: u16 = url.port orelse switch (protocol) {
@@ -133,8 +112,6 @@ pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
         .protocol = protocol,
         .tls_client = undefined,
     };
-    client.active_requests += 1;
-    errdefer req.deinit();
 
     switch (protocol) {
         .http => {},
@@ -146,36 +123,30 @@ pub fn request(client: *Client, url: Url, options: Request.Options) !Request {
         },
     }
 
-    try req.headers.ensureUnusedCapacity(
-        client.allocator,
-        @tagName(options.method).len +
-            1 +
-            url.path.len +
-            " HTTP/1.1\r\nHost: ".len +
-            url.host.len +
-            "\r\nUpgrade-Insecure-Requests: 1\r\n".len +
-            client.headers.items.len +
-            2, // for the \r\n at the end of headers
-    );
-    req.headers.appendSliceAssumeCapacity(@tagName(options.method));
-    req.headers.appendSliceAssumeCapacity(" ");
-    req.headers.appendSliceAssumeCapacity(url.path);
-    req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
-    req.headers.appendSliceAssumeCapacity(url.host);
-    switch (protocol) {
-        .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),
-        .http => req.headers.appendSliceAssumeCapacity("\r\n"),
+    {
+        var h = try std.BoundedArray(u8, 1000).init(0);
+        try h.appendSlice(@tagName(headers.method));
+        try h.appendSlice(" ");
+        try h.appendSlice(url.path);
+        try h.appendSlice(" HTTP/1.1\r\nHost: ");
+        try h.appendSlice(url.host);
+        switch (protocol) {
+            .https => try h.appendSlice("\r\nUpgrade-Insecure-Requests: 1\r\n"),
+            .http => try h.appendSlice("\r\n"),
+        }
+        try h.writer().print("Connection: {s}\r\n", .{@tagName(headers.connection)});
+        try h.appendSlice("\r\n");
+
+        const header_bytes = h.slice();
+        switch (req.protocol) {
+            .http => {
+                try req.stream.writeAll(header_bytes);
+            },
+            .https => {
+                try req.tls_client.writeAll(req.stream, header_bytes);
+            },
+        }
     }
-    req.headers.appendSliceAssumeCapacity(client.headers.items);
 
     return req;
 }
-
-pub fn addHeader(client: *Client, name: []const u8, value: []const u8) !void {
-    const gpa = client.allocator;
-    try client.headers.ensureUnusedCapacity(gpa, name.len + value.len + 4);
-    client.headers.appendSliceAssumeCapacity(name);
-    client.headers.appendSliceAssumeCapacity(": ");
-    client.headers.appendSliceAssumeCapacity(value);
-    client.headers.appendSliceAssumeCapacity("\r\n");
-}