Commit 524e0cd987

Nameless <truemedian@gmail.com>
2023-03-08 18:27:13
std.http: rework connection pool into its own type
1 parent 634e715
Changed files (4)
lib/std/http/Client/Request.zig
@@ -6,7 +6,7 @@ const assert = std.debug.assert;
 
 const Client = @import("../Client.zig");
 const Connection = Client.Connection;
-const ConnectionNode = Client.ConnectionNode;
+const ConnectionNode = Client.ConnectionPool.Node;
 const Response = @import("Response.zig");
 
 const Request = @This();
@@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void {
     if (!req.response.done) {
         // If the response wasn't fully read, then we need to close the connection.
         req.connection.data.closing = true;
-        req.client.release(req.connection);
+        req.client.connection_pool.release(req.client, req.connection);
     }
 
     req.arena.deinit();
@@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
     if (req.response.state == .finished) {
         req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
 
-        if (req.response.upgrade) |_| {
+        if (req.response.headers.upgrade) |_| {
             req.connection.data.closing = false;
             req.response.done = true;
             return i;
@@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                     req.response.next_chunk_length -= can_read;
 
                     if (req.response.next_chunk_length == 0) {
-                        req.client.release(req.connection);
+                        req.client.connection_pool.release(req.client, req.connection);
                         req.connection = undefined;
                         req.response.done = true;
                     }
@@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                 req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
 
                 if (req.response.next_chunk_length == 0) {
-                    req.client.release(req.connection);
+                    req.client.connection_pool.release(req.client, req.connection);
                     req.connection = undefined;
                     req.response.done = true;
                 }
@@ -293,7 +293,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                     .chunk_data => {
                         if (req.response.next_chunk_length == 0) {
                             req.response.done = true;
-                            req.client.release(req.connection);
+                            req.client.connection_pool.release(req.client, req.connection);
                             req.connection = undefined;
 
                             return out_index;
@@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
                     req.response.next_chunk_length -= can_read;
 
                     if (req.response.next_chunk_length == 0) {
-                        req.client.release(req.connection);
+                        req.client.connection_pool.release(req.client, req.connection);
                         req.connection = undefined;
                         req.response.done = true;
                         continue;
@@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
     }
 }
 
-pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{
-    BadHeader,
-    InvalidCompression,
-    StreamTooLong,
-    InvalidWindowSize,
-    CompressionNotSupported
-};
+pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported };
 
 pub const Reader = std.io.Reader(*Request, ReadError, read);
 
lib/std/http/Client/Response.zig
@@ -32,6 +32,7 @@ pub const Headers = struct {
     transfer_encoding: ?http.TransferEncoding = null,
     transfer_compression: ?http.ContentEncoding = null,
     connection: http.Connection = .close,
+    upgrade: ?[]const u8 = null,
 
     number_of_headers: usize = 0,
 
@@ -93,7 +94,7 @@ pub const Headers = struct {
 
                 if (iter.next()) |second| {
                     if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
-                        
+
                     const trimmed = std.mem.trim(u8, second, " ");
 
                     if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
@@ -122,6 +123,8 @@ pub const Headers = struct {
                 } else {
                     return error.HttpConnectionHeaderUnsupported;
                 }
+            } else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
+                headers.upgrade = header_value;
             }
         }
 
lib/std/http/Client.zig
@@ -16,6 +16,9 @@ const testing = std.testing;
 pub const Request = @import("Client/Request.zig");
 pub const Response = @import("Client/Response.zig");
 
+pub const default_connection_pool_size = 32;
+const connection_pool_size = std.options.http_connection_pool_size;
+
 /// Used for tcpConnectToHost and storing HTTP headers when an externally
 /// managed buffer is not provided.
 allocator: Allocator,
@@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
 /// it will first rescan the system for root certificates.
 next_https_rescan_certs: bool = true,
 
-connection_mutex: std.Thread.Mutex = .{},
 connection_pool: ConnectionPool = .{},
-connection_used: ConnectionPool = .{},
 
-pub const ConnectionPool = std.TailQueue(Connection);
-pub const ConnectionNode = ConnectionPool.Node;
+pub const ConnectionPool = struct {
+    pub const Criteria = struct {
+        host: []const u8,
+        port: u16,
+        is_tls: bool,
+    };
 
-/// Acquires an existing connection from the connection pool. This function is threadsafe.
-/// If the caller already holds the connection mutex, it should pass `true` for `held`.
-pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
-    if (!held) client.connection_mutex.lock();
-    defer if (!held) client.connection_mutex.unlock();
+    const Queue = std.TailQueue(Connection);
+    pub const Node = Queue.Node;
+
+    mutex: std.Thread.Mutex = .{},
+    used: Queue = .{},
+    free: Queue = .{},
+    free_len: usize = 0,
+    free_size: usize = default_connection_pool_size,
+
+    /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
+    /// If no connection is found, null is returned.
+    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        var next = pool.free.last;
+        while (next) |node| : (next = node.prev) {
+            if ((node.data.protocol == .tls) != criteria.is_tls) continue;
+            if (node.data.port != criteria.port) continue;
+            if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
+
+            pool.acquireUnsafe(node);
+            return node;
+        }
 
-    client.connection_pool.remove(node);
-    client.connection_used.append(node);
-}
+        return null;
+    }
 
-/// Tries to release a connection back to the connection pool. This function is threadsafe.
-/// If the connection is marked as closing, it will be closed instead.
-pub fn release(client: *Client, node: *ConnectionNode) void {
-    client.connection_mutex.lock();
-    defer client.connection_mutex.unlock();
+    /// Acquires an existing connection from the connection pool. This function is not threadsafe.
+    pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
+        pool.free.remove(node);
+        pool.free_len -= 1;
 
-    client.connection_used.remove(node);
+        pool.used.append(node);
+    }
 
-    if (node.data.closing) {
-        node.data.close(client);
+    /// Acquires an existing connection from the connection pool. This function is threadsafe.
+    pub fn acquire(pool: *ConnectionPool, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
 
-        return client.allocator.destroy(node);
+        return pool.acquireUnsafe(node);
     }
 
-    client.connection_pool.append(node);
-}
+    /// Tries to release a connection back to the connection pool. This function is threadsafe.
+    /// If the connection is marked as closing, it will be closed instead.
+    pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        pool.used.remove(node);
+
+        if (node.data.closing) {
+            node.data.close(client);
+
+            return client.allocator.destroy(node);
+        }
+
+        if (pool.free_len + 1 >= pool.free_size) {
+            const popped = pool.free.popFirst() orelse unreachable;
+
+            popped.data.close(client);
+
+            return client.allocator.destroy(popped);
+        }
+
+        pool.free.append(node);
+        pool.free_len += 1;
+    }
+
+    /// Adds a newly created node to the pool of used connections. This function is threadsafe.
+    pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        pool.used.append(node);
+    }
+
+    pub fn deinit(pool: *ConnectionPool, client: *Client) void {
+        pool.mutex.lock();
+
+        var next = pool.free.first;
+        while (next) |node| {
+            defer client.allocator.destroy(node);
+            next = node.next;
+
+            node.data.close(client);
+        }
+
+        next = pool.used.first;
+        while (next) |node| {
+            defer client.allocator.destroy(node);
+            next = node.next;
+
+            node.data.close(client);
+        }
+
+        pool.* = undefined;
+    }
+};
 
 pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
 pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
@@ -142,25 +221,7 @@ pub const Connection = struct {
 };
 
 pub fn deinit(client: *Client) void {
-    client.connection_mutex.lock();
-
-    var next = client.connection_pool.first;
-    while (next) |node| {
-        next = node.next;
-
-        node.data.close(client);
-
-        client.allocator.destroy(node);
-    }
-
-    next = client.connection_used.first;
-    while (next) |node| {
-        next = node.next;
-
-        node.data.close(client);
-
-        client.allocator.destroy(node);
-    }
+    client.connection_pool.deinit(client);
 
     client.ca_bundle.deinit(client.allocator);
     client.* = undefined;
@@ -168,36 +229,25 @@ pub fn deinit(client: *Client) void {
 
 pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
 
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
-    { // Search through the connection pool for a potential connection.
-        client.connection_mutex.lock();
-        defer client.connection_mutex.unlock();
-
-        var potential = client.connection_pool.last;
-        while (potential) |node| {
-            const same_host = mem.eql(u8, node.data.host, host);
-            const same_port = node.data.port == port;
-            const same_protocol = node.data.protocol == protocol;
-
-            if (same_host and same_port and same_protocol) {
-                client.acquire(node, true);
-                return node;
-            }
-
-            potential = node.prev;
-        }
-    }
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+    if (client.connection_pool.findConnection(.{
+        .host = host,
+        .port = port,
+        .is_tls = protocol == .tls,
+    })) |node|
+        return node;
 
-    const conn = try client.allocator.create(ConnectionNode);
+    const conn = try client.allocator.create(ConnectionPool.Node);
     errdefer client.allocator.destroy(conn);
+    conn.* = .{ .data = undefined };
 
-    conn.* = .{ .data = .{
+    conn.data = .{
         .stream = try net.tcpConnectToHost(client.allocator, host, port),
         .tls_client = undefined,
         .protocol = protocol,
         .host = try client.allocator.dupe(u8, host),
         .port = port,
-    } };
+    };
 
     switch (protocol) {
         .plain => {},
@@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
         },
     }
 
-    {
-        client.connection_mutex.lock();
-        defer client.connection_mutex.unlock();
-
-        client.connection_used.append(conn);
-    }
+    client.connection_pool.addUsed(conn);
 
     return conn;
 }
@@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
     const host = uri.host orelse return error.UriMissingHost;
 
     if (client.next_https_rescan_certs and protocol == .tls) {
-        client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
-        defer client.connection_mutex.unlock();
+        client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
+        defer client.connection_pool.mutex.unlock();
 
         if (client.next_https_rescan_certs) {
             try client.ca_bundle.rescan(client.allocator);
lib/std/std.zig
@@ -185,6 +185,11 @@ pub const options = struct {
         options_override.keep_sigpipe
     else
         false;
+
+    pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size"))
+        options_override.http_connection_pool_size
+    else
+        http.Client.default_connection_pool_size;
 };
 
 // This forces the start.zig file to be imported, and the comptime logic inside that