Commit afb26f4e6b

Nameless <truemedian@gmail.com>
2023-03-02 19:45:34
std.http: add connection pooling and make keep-alive requests by default
1 parent 95f6a59
Changed files (1)
lib
std
lib/std/http/Client.zig
@@ -21,11 +21,27 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
 /// it will first rescan the system for root certificates.
 next_https_rescan_certs: bool = true,
 
+connection_pool: std.TailQueue(Connection) = .{},
+
+const ConnectionPool = std.TailQueue(Connection);
+const ConnectionNode = ConnectionPool.Node;
+
+pub fn release(client: *Client, node: *ConnectionNode) void {
+    if (node.data.unusable) return node.data.close(client);
+
+    client.connection_pool.append(node);
+}
+
 pub const Connection = struct {
     stream: net.Stream,
     /// undefined unless protocol is tls.
-    tls_client: std.crypto.tls.Client,
+    tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
     protocol: Protocol,
+    host: []u8,
+    port: u16,
+
+    // This connection has been part of a non keepalive request and cannot be added to the pool.
+    unusable: bool = false,
 
     pub const Protocol = enum { plain, tls };
 
@@ -56,6 +72,17 @@ pub const Connection = struct {
             .tls => return conn.tls_client.write(conn.stream, buffer),
         }
     }
+
+    pub fn close(conn: *Connection, client: *const Client) void {
+        if (conn.protocol == .tls) {
+            // try to cleanly close the TLS connection, for any server that cares.
+            _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
+        }
+
+        conn.stream.close();
+
+        client.allocator.free(conn.host);
+    }
 };
 
 /// TODO: emit error.UnexpectedEndOfStream or something like that when the read
@@ -63,7 +90,7 @@ pub const Connection = struct {
 /// close_notify protection on underlying TLS streams.
 pub const Request = struct {
     client: *Client,
-    connection: Connection,
+    connection: *ConnectionNode,
     redirects_left: u32,
     response: Response,
     /// These are stored in Request so that they are available when following
@@ -79,6 +106,7 @@ pub const Request = struct {
         header_bytes: std.ArrayListUnmanaged(u8),
         max_header_bytes: usize,
         next_chunk_length: u64,
+        done: bool,
 
         pub const Headers = struct {
             status: http.Status,
@@ -86,6 +114,7 @@ pub const Request = struct {
             location: ?[]const u8 = null,
             content_length: ?u64 = null,
             transfer_encoding: ?http.TransferEncoding = null,
+            connection_close: bool = true,
 
             pub fn parse(bytes: []const u8) !Response.Headers {
                 var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
@@ -126,6 +155,14 @@ pub const Request = struct {
                         if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
                         headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse
                             return error.HttpTransferEncodingUnsupported;
+                    } else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+                        if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
+                            headers.connection_close = false;
+                        } else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
+                            headers.connection_close = true;
+                        } else {
+                            return error.HttpConnectionHeaderUnsupported;
+                        }
                     }
                 }
 
@@ -185,10 +222,10 @@ pub const Request = struct {
             chunk_r,
             chunk_data,
 
-            pub fn zeroMeansEnd(state: State) bool {
-                return switch (state) {
-                    .finished, .chunk_data => true,
-                    else => false,
+            pub fn isContent(self: State) bool {
+                return switch (self) {
+                    .invalid, .start, .seen_r, .seen_rn, .seen_rnr => false,
+                    .finished, .chunk_size_prefix_r, .chunk_size_prefix_n, .chunk_size, .chunk_r, .chunk_data => true,
                 };
             }
         };
@@ -201,6 +238,7 @@ pub const Request = struct {
                 .max_header_bytes = max,
                 .header_bytes_owned = true,
                 .next_chunk_length = undefined,
+                .done = false,
             };
         }
 
@@ -212,6 +250,7 @@ pub const Request = struct {
                 .max_header_bytes = buf.len,
                 .header_bytes_owned = false,
                 .next_chunk_length = undefined,
+                .done = false,
             };
         }
 
@@ -501,6 +540,7 @@ pub const Request = struct {
     pub const Headers = struct {
         version: http.Version = .@"HTTP/1.1",
         method: http.Method = .GET,
+        connection_close: bool = false,
     };
 
     pub const Options = struct {
@@ -545,6 +585,7 @@ pub const Request = struct {
         HttpHeadersExceededSizeLimit,
         HttpRedirectMissingLocation,
         HttpTransferEncodingUnsupported,
+        HttpConnectionHeaderUnsupported,
         HttpContentLengthUnknown,
         TooManyHttpRedirects,
         ShortHttpStatusLine,
@@ -669,8 +710,9 @@ pub const Request = struct {
         assert(len <= buffer.len);
         var index: usize = 0;
         while (index < len) {
-            const zero_means_end = req.response.state.zeroMeansEnd();
             const amt = try readAdvanced(req, buffer[index..]);
+            const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
+
             if (amt == 0 and zero_means_end) break;
             index += amt;
         }
@@ -680,7 +722,29 @@ pub const Request = struct {
     /// This one can return 0 without meaning EOF.
     /// TODO change to readvAdvanced
     pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
-        var in = buffer[0..try req.connection.read(buffer)];
+        if (req.response.done) {
+            if (req.response.headers.status.class() == .redirect) {
+                if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+                const location = req.response.headers.location orelse
+                    return error.HttpRedirectMissingLocation;
+                const new_url = try std.Uri.parse(location);
+                const new_req = try req.client.request(new_url, req.headers, .{
+                    .max_redirects = req.redirects_left - 1,
+                    .header_strategy = if (req.response.header_bytes_owned) .{
+                        .dynamic = req.response.max_header_bytes,
+                    } else .{
+                        .static = req.response.header_bytes.unusedCapacitySlice(),
+                    },
+                });
+                req.deinit();
+                req.* = new_req;
+            } else {
+                return 0;
+            }
+        }
+
+        var in = buffer[0..try req.connection.data.read(buffer)];
         var out_index: usize = 0;
         while (true) {
             switch (req.response.state) {
@@ -698,24 +762,10 @@ pub const Request = struct {
                     if (req.response.state == .finished) {
                         req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
 
-                        if (req.response.headers.status.class() == .redirect) {
-                            if (req.redirects_left == 0) return error.TooManyHttpRedirects;
-                            const location = req.response.headers.location orelse
-                                return error.HttpRedirectMissingLocation;
-                            const new_url = try std.Uri.parse(location);
-                            const new_req = try req.client.request(new_url, req.headers, .{
-                                .max_redirects = req.redirects_left - 1,
-                                .header_strategy = if (req.response.header_bytes_owned) .{
-                                    .dynamic = req.response.max_header_bytes,
-                                } else .{
-                                    .static = req.response.header_bytes.unusedCapacitySlice(),
-                                },
-                            });
-                            req.deinit();
-                            req.* = new_req;
-                            assert(out_index == 0);
-                            in = buffer[0..try req.connection.read(buffer)];
-                            continue;
+                        if (req.response.headers.connection_close == true) {
+                            req.connection.data.unusable = true;
+                        } else {
+                            req.connection.data.unusable = false;
                         }
 
                         if (req.response.headers.transfer_encoding) |transfer_encoding| {
@@ -742,11 +792,29 @@ pub const Request = struct {
                     return 0;
                 },
                 .finished => {
+                    const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
+                    req.response.next_chunk_length -= sub_amt;
+
+                    if (req.response.next_chunk_length == 0) {
+                        req.client.release(req.connection);
+                        req.connection = undefined;
+
+                        req.response.done = true;
+                        assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary.
+
+                        if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;
+
+                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
+                        return out_index + sub_amt;
+                    }
+
+                    if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;
+
                     if (in.ptr == buffer.ptr) {
-                        return in.len;
+                        return sub_amt;
                     } else {
-                        mem.copy(u8, buffer[out_index..], in);
-                        return out_index + in.len;
+                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
+                        return out_index + sub_amt;
                     }
                 },
                 .chunk_size_prefix_r => switch (in.len) {
@@ -793,7 +861,10 @@ pub const Request = struct {
                         .invalid => return error.HttpHeadersInvalid,
                         .chunk_data => {
                             if (req.response.next_chunk_length == 0) {
-                                req.response.state = .start;
+                                req.response.done = true;
+                                req.client.release(req.connection);
+                                req.connection = undefined;
+
                                 return out_index;
                             }
                             in = in[i..];
@@ -807,20 +878,27 @@ pub const Request = struct {
                     // TODO https://github.com/ziglang/zig/issues/14039
                     const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
                     req.response.next_chunk_length -= sub_amt;
-                    if (req.response.next_chunk_length > 0) {
-                        if (in.ptr == buffer.ptr) {
-                            return sub_amt;
-                        } else {
-                            mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
-                            out_index += sub_amt;
-                            return out_index;
-                        }
+
+                    if (req.response.next_chunk_length == 0) {
+                        req.response.state = .chunk_size_prefix_r;
+                        in = in[sub_amt..];
+
+                        if (req.response.headers.status.class() == .redirect) continue;
+
+                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
+                        out_index += sub_amt;
+                        continue;
+                    }
+
+                    if (req.response.headers.status.class() == .redirect) return 0;
+
+                    if (in.ptr == buffer.ptr) {
+                        return sub_amt;
+                    } else {
+                        mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
+                        out_index += sub_amt;
+                        return out_index;
                     }
-                    mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
-                    out_index += sub_amt;
-                    req.response.state = .chunk_size_prefix_r;
-                    in = in[sub_amt..];
-                    continue;
                 },
             }
         }
@@ -844,24 +922,52 @@ pub const Request = struct {
 };
 
 pub fn deinit(client: *Client) void {
+    var next = client.connection_pool.first;
+    while (next) |node| {
+        next = node.next;
+
+        node.data.close(client);
+
+        client.allocator.destroy(node);
+    }
+
     client.ca_bundle.deinit(client.allocator);
     client.* = undefined;
 }
 
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !Connection {
-    var conn: Connection = .{
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode {
+    var potential = client.connection_pool.last;
+    while (potential) |node| {
+        const same_host = mem.eql(u8, node.data.host, host);
+        const same_port = node.data.port == port;
+        const same_protocol = node.data.protocol == protocol;
+
+        if (same_host and same_port and same_protocol) {
+            client.connection_pool.remove(node);
+            return node;
+        }
+
+        potential = node.prev;
+    }
+
+    const conn = try client.allocator.create(ConnectionNode);
+    errdefer client.allocator.destroy(conn);
+
+    conn.* = .{ .data = .{
         .stream = try net.tcpConnectToHost(client.allocator, host, port),
         .tls_client = undefined,
         .protocol = protocol,
-    };
+        .host = try client.allocator.dupe(u8, host),
+        .port = port,
+    } };
 
     switch (protocol) {
         .plain => {},
         .tls => {
-            conn.tls_client = try std.crypto.tls.Client.init(conn.stream, client.ca_bundle, host);
+            conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
             // This is appropriate for HTTPS because the HTTP headers contain
             // the content length which is used to detect truncation attacks.
-            conn.tls_client.allow_truncation_attacks = true;
+            conn.data.tls_client.allow_truncation_attacks = true;
         },
     }
 
@@ -908,10 +1014,15 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
         try h.appendSlice(@tagName(headers.version));
         try h.appendSlice("\r\nHost: ");
         try h.appendSlice(host);
-        try h.appendSlice("\r\nConnection: close\r\n\r\n");
+        if (headers.connection_close) {
+            try h.appendSlice("\r\nConnection: close");
+        } else {
+            try h.appendSlice("\r\nConnection: keep-alive");
+        }
+        try h.appendSlice("\r\n\r\n");
 
         const header_bytes = h.slice();
-        try req.connection.writeAll(header_bytes);
+        try req.connection.data.writeAll(header_bytes);
     }
 
     return req;