Commit e1c37f70d4

Nameless <truemedian@gmail.com>
2023-10-03 21:26:06
std.http.Client: store *Connection instead of a pool node, buffer writes
1 parent 1afeada
Changed files (4)
lib
test
standalone
lib/std/crypto/tls/Client.zig
@@ -881,7 +881,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
 /// The `iovecs` parameter is mutable because this function needs to mutate the fields in
 /// order to handle partial reads from the underlying stream layer.
 pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
-    return readvAtLeast(c, stream, iovecs);
+    return readvAtLeast(c, stream, iovecs, 1);
 }
 
 /// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
lib/std/http/Client.zig
@@ -54,7 +54,7 @@ pub const ConnectionPool = struct {
 
     /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
     /// If no connection is found, null is returned.
-    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
+    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
         pool.mutex.lock();
         defer pool.mutex.unlock();
 
@@ -65,7 +65,7 @@ pub const ConnectionPool = struct {
             if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue;
 
             pool.acquireUnsafe(node);
-            return node;
+            return &node.data;
         }
 
         return null;
@@ -89,10 +89,12 @@ pub const ConnectionPool = struct {
 
     /// Tries to release a connection back to the connection pool. This function is threadsafe.
     /// If the connection is marked as closing, it will be closed instead.
-    pub fn release(pool: *ConnectionPool, allocator: Allocator, node: *Node) void {
+    pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
         pool.mutex.lock();
         defer pool.mutex.unlock();
 
+        const node = @fieldParentPtr(Node, "data", connection);
+
         pool.used.remove(node);
 
         if (node.data.closing or pool.free_size == 0) {
@@ -151,6 +153,8 @@ pub const ConnectionPool = struct {
 /// An interface to either a plain or TLS connection.
 pub const Connection = struct {
     pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
+    const BufferSize = std.math.IntFittingRange(0, buffer_size);
+
     pub const Protocol = enum { plain, tls };
 
     stream: net.Stream,
@@ -164,14 +168,16 @@ pub const Connection = struct {
     proxied: bool = false,
     closing: bool = false,
 
-    read_start: u16 = 0,
-    read_end: u16 = 0,
+    read_start: BufferSize = 0,
+    read_end: BufferSize = 0,
+    write_end: BufferSize = 0,
     read_buf: [buffer_size]u8 = undefined,
+    write_buf: [buffer_size]u8 = undefined,
 
-    pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+    pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
         return switch (conn.protocol) {
-            .plain => conn.stream.readAtLeast(buffer, len),
-            .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
+            .plain => conn.stream.readv(buffers),
+            .tls => conn.tls_client.readv(conn.stream, buffers),
         } catch |err| {
             // TODO: https://github.com/ziglang/zig/issues/2473
             if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
@@ -188,58 +194,52 @@ pub const Connection = struct {
     pub fn fill(conn: *Connection) ReadError!void {
         if (conn.read_end != conn.read_start) return;
 
-        const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
+        var iovecs = [1]std.os.iovec{
+            .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
+        };
+        const nread = try conn.readvDirect(&iovecs);
         if (nread == 0) return error.EndOfStream;
         conn.read_start = 0;
-        conn.read_end = @as(u16, @intCast(nread));
+        conn.read_end = @intCast(nread);
     }
 
     pub fn peek(conn: *Connection) []const u8 {
         return conn.read_buf[conn.read_start..conn.read_end];
     }
 
-    pub fn drop(conn: *Connection, num: u16) void {
+    pub fn drop(conn: *Connection, num: BufferSize) void {
         conn.read_start += num;
     }
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
-        assert(len <= buffer.len);
-
-        var out_index: u16 = 0;
-        while (out_index < len) {
-            const available_read = conn.read_end - conn.read_start;
-            const available_buffer = buffer.len - out_index;
-
-            if (available_read > available_buffer) { // partially read buffered data
-                @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
-                out_index += @as(u16, @intCast(available_buffer));
-                conn.read_start += @as(u16, @intCast(available_buffer));
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        const available_read = conn.read_end - conn.read_start;
+        const available_buffer = buffer.len;
 
-                break;
-            } else if (available_read > 0) { // fully read buffered data
-                @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
-                out_index += available_read;
-                conn.read_start += available_read;
+        if (available_read > available_buffer) { // partially read buffered data
+            @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
+            conn.read_start += @intCast(available_buffer);
 
-                if (out_index >= len) break;
-            }
+            return available_buffer;
+        } else if (available_read > 0) { // fully read buffered data
+            @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
+            conn.read_start += available_read;
 
-            const leftover_buffer = available_buffer - available_read;
-            const leftover_len = len - out_index;
+            return available_read;
+        }
 
-            if (leftover_buffer > conn.read_buf.len) {
-                // skip the buffer if the output is large enough
-                return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
-            }
+        var iovecs = [2]std.os.iovec{
+            .{ .iov_base = buffer.ptr, .iov_len = buffer.len },
+            .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
+        };
+        const nread = try conn.readvDirect(&iovecs);
 
-            try conn.fill();
+        if (nread > buffer.len) {
+            conn.read_start = 0;
+            conn.read_end = @intCast(nread - buffer.len);
+            return buffer.len;
         }
 
-        return out_index;
-    }
-
-    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
-        return conn.readAtLeast(buffer, 1);
+        return nread;
     }
 
     pub const ReadError = error{
@@ -257,7 +257,7 @@ pub const Connection = struct {
         return Reader{ .context = conn };
     }
 
-    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
+    pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
         return switch (conn.protocol) {
             .plain => conn.stream.writeAll(buffer),
             .tls => conn.tls_client.writeAll(conn.stream, buffer),
@@ -267,14 +267,27 @@ pub const Connection = struct {
         };
     }
 
-    pub fn write(conn: *Connection, buffer: []const u8) !usize {
-        return switch (conn.protocol) {
-            .plain => conn.stream.write(buffer),
-            .tls => conn.tls_client.write(conn.stream, buffer),
-        } catch |err| switch (err) {
-            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedWriteFailure,
-        };
+    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
+        if (conn.write_end + buffer.len > conn.write_buf.len) {
+            try conn.flush();
+
+            if (buffer.len > conn.write_buf.len) {
+                try conn.writeAllDirect(buffer);
+                return buffer.len;
+            }
+        }
+
+        @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer);
+        conn.write_end += @intCast(buffer.len);
+
+        return buffer.len;
+    }
+
+    pub fn flush(conn: *Connection) WriteError!void {
+        if (conn.write_end == 0) return;
+
+        try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
+        conn.write_end = 0;
     }
 
     pub const WriteError = error{
@@ -455,7 +468,7 @@ pub const Request = struct {
     uri: Uri,
     client: *Client,
     /// is null when this connection is released
-    connection: ?*ConnectionPool.Node,
+    connection: ?*Connection,
 
     method: http.Method,
     version: http.Version = .@"HTTP/1.1",
@@ -489,7 +502,7 @@ pub const Request = struct {
         if (req.connection) |connection| {
             if (!req.response.parser.done) {
                 // If the response wasn't fully read, then we need to close the connection.
-                connection.data.closing = true;
+                connection.closing = true;
             }
             req.client.connection_pool.release(req.client.allocator, connection);
         }
@@ -548,8 +561,7 @@ pub const Request = struct {
     pub fn start(req: *Request, options: StartOptions) StartError!void {
         if (!req.method.requestHasBody() and req.transfer_encoding != .none) return error.UnsupportedTransferEncoding;
 
-        var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
-        const w = buffered.writer();
+        const w = req.connection.?.writer();
 
         try req.method.write(w);
         try w.writeByte(' ');
@@ -558,9 +570,9 @@ pub const Request = struct {
             try req.uri.writeToStream(.{ .authority = true }, w);
         } else {
             try req.uri.writeToStream(.{
-                .scheme = req.connection.?.data.proxied,
-                .authentication = req.connection.?.data.proxied,
-                .authority = req.connection.?.data.proxied,
+                .scheme = req.connection.?.proxied,
+                .authentication = req.connection.?.proxied,
+                .authority = req.connection.?.proxied,
                 .path = true,
                 .query = true,
                 .raw = options.raw_uri,
@@ -629,8 +641,8 @@ pub const Request = struct {
             try w.writeAll("\r\n");
         }
 
-        if (req.connection.?.data.proxied) {
-            const proxy_headers: ?http.Headers = switch (req.connection.?.data.protocol) {
+        if (req.connection.?.proxied) {
+            const proxy_headers: ?http.Headers = switch (req.connection.?.protocol) {
                 .plain => if (req.client.http_proxy) |proxy| proxy.headers else null,
                 .tls => if (req.client.https_proxy) |proxy| proxy.headers else null,
             };
@@ -649,7 +661,7 @@ pub const Request = struct {
 
         try w.writeAll("\r\n");
 
-        try buffered.flush();
+        try req.connection.?.flush();
     }
 
     const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
@@ -665,7 +677,7 @@ pub const Request = struct {
 
         var index: usize = 0;
         while (index == 0) {
-            const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
+            const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
             if (amt == 0 and req.response.parser.done) break;
             index += amt;
         }
@@ -683,10 +695,10 @@ pub const Request = struct {
     pub fn wait(req: *Request) WaitError!void {
         while (true) { // handle redirects
             while (true) { // read headers
-                try req.connection.?.data.fill();
+                try req.connection.?.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
-                req.connection.?.data.drop(@as(u16, @intCast(nchecked)));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek());
+                req.connection.?.drop(@intCast(nchecked));
 
                 if (req.response.parser.state.isContent()) break;
             }
@@ -701,7 +713,7 @@ pub const Request = struct {
 
             // we're switching protocols, so this connection is no longer doing http
             if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) {
-                req.connection.?.data.closing = false;
+                req.connection.?.closing = false;
                 req.response.parser.done = true;
             }
 
@@ -712,9 +724,9 @@ pub const Request = struct {
             const res_connection = req.response.headers.getFirstValue("connection");
             const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
             if (res_keepalive and (req_keepalive or req_connection == null)) {
-                req.connection.?.data.closing = false;
+                req.connection.?.closing = false;
             } else {
-                req.connection.?.data.closing = true;
+                req.connection.?.closing = true;
             }
 
             if (req.response.transfer_encoding) |te| {
@@ -827,10 +839,10 @@ pub const Request = struct {
             const has_trail = !req.response.parser.state.isContent();
 
             while (!req.response.parser.state.isContent()) { // read trailing headers
-                try req.connection.?.data.fill();
+                try req.connection.?.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
-                req.connection.?.data.drop(@as(u16, @intCast(nchecked)));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.peek());
+                req.connection.?.drop(@intCast(nchecked));
             }
 
             if (has_trail) {
@@ -868,16 +880,16 @@ pub const Request = struct {
     pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
         switch (req.transfer_encoding) {
             .chunked => {
-                try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
-                try req.connection.?.data.writeAll(bytes);
-                try req.connection.?.data.writeAll("\r\n");
+                try req.connection.?.writer().print("{x}\r\n", .{bytes.len});
+                try req.connection.?.writer().writeAll(bytes);
+                try req.connection.?.writer().writeAll("\r\n");
 
                 return bytes.len;
             },
             .content_length => |*len| {
                 if (len.* < bytes.len) return error.MessageTooLong;
 
-                const amt = try req.connection.?.data.write(bytes);
+                const amt = try req.connection.?.write(bytes);
                 len.* -= amt;
                 return amt;
             },
@@ -897,10 +909,12 @@ pub const Request = struct {
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
         switch (req.transfer_encoding) {
-            .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
+            .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
+
+        try req.connection.?.flush();
     }
 };
 
@@ -1024,7 +1038,7 @@ pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, Network
 
 /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
 /// This function is threadsafe.
-pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*ConnectionPool.Node {
+pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
     if (client.connection_pool.findConnection(.{
         .host = host,
         .port = port,
@@ -1074,12 +1088,12 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
 
     client.connection_pool.addUsed(conn);
 
-    return conn;
+    return &conn.data;
 }
 
 pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{ NameTooLong, Unsupported } || std.os.ConnectError;
 
-pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node {
+pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection {
     if (!net.has_unix_sockets) return error.Unsupported;
 
     if (client.connection_pool.findConnection(.{
@@ -1108,7 +1122,7 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti
 
     client.connection_pool.addUsed(conn);
 
-    return conn;
+    return &conn.data;
 }
 
 pub fn connectTunnel(
@@ -1116,7 +1130,7 @@ pub fn connectTunnel(
     proxy: *ProxyInformation,
     tunnel_host: []const u8,
     tunnel_port: u16,
-) !*ConnectionPool.Node {
+) !*Connection {
     if (!proxy.supports_connect) return error.TunnelNotSupported;
 
     if (client.connection_pool.findConnection(.{
@@ -1130,7 +1144,7 @@ pub fn connectTunnel(
     _ = tunnel: {
         const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
         errdefer {
-            conn.data.closing = true;
+            conn.closing = true;
             client.connection_pool.release(client.allocator, conn);
         }
 
@@ -1171,12 +1185,12 @@ pub fn connectTunnel(
         // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized.
         req.connection = null;
 
-        client.allocator.free(conn.data.host);
-        conn.data.host = try client.allocator.dupe(u8, tunnel_host);
-        errdefer client.allocator.free(conn.data.host);
+        client.allocator.free(conn.host);
+        conn.host = try client.allocator.dupe(u8, tunnel_host);
+        errdefer client.allocator.free(conn.host);
 
-        conn.data.port = tunnel_port;
-        conn.data.closing = false;
+        conn.port = tunnel_port;
+        conn.closing = false;
 
         return conn;
     } catch {
@@ -1190,7 +1204,7 @@ pub fn connectTunnel(
 const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUrlScheme, ConnectionRefused };
 pub const ConnectError = ConnectErrorPartial || RequestError;
 
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*Connection {
     // pointer required so that `supports_connect` can be updated if a CONNECT fails
     const potential_proxy: ?*ProxyInformation = switch (protocol) {
         .plain => if (client.http_proxy) |*proxy_info| proxy_info else null,
@@ -1213,11 +1227,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
         // fall back to using the proxy as a normal http proxy
         const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
         errdefer {
-            conn.data.closing = true;
+            conn.closing = true;
             client.connection_pool.release(conn);
         }
 
-        conn.data.proxied = true;
+        conn.proxied = true;
         return conn;
     }
 
@@ -1240,7 +1254,7 @@ pub const RequestOptions = struct {
     header_strategy: StorageStrategy = .{ .dynamic = 16 * 1024 },
 
     /// Must be an already acquired connection.
-    connection: ?*ConnectionPool.Node = null,
+    connection: ?*Connection = null,
 
     pub const StorageStrategy = union(enum) {
         /// In this case, the client's Allocator will be used to store the
lib/std/http/protocol.zig
@@ -529,7 +529,7 @@ pub const HeadersParser = struct {
                         try conn.fill();
 
                         const nread = @min(conn.peek().len, data_avail);
-                        conn.drop(@as(u16, @intCast(nread)));
+                        conn.drop(@intCast(nread));
                         r.next_chunk_length -= nread;
 
                         if (r.next_chunk_length == 0) r.done = true;
@@ -553,7 +553,7 @@ pub const HeadersParser = struct {
                     try conn.fill();
 
                     const i = r.findChunkedLen(conn.peek());
-                    conn.drop(@as(u16, @intCast(i)));
+                    conn.drop(@intCast(i));
 
                     switch (r.state) {
                         .invalid => return error.HttpChunkInvalid,
@@ -582,7 +582,7 @@ pub const HeadersParser = struct {
                         try conn.fill();
 
                         const nread = @min(conn.peek().len, data_avail);
-                        conn.drop(@as(u16, @intCast(nread)));
+                        conn.drop(@intCast(nread));
                         r.next_chunk_length -= nread;
                     } else if (out_avail > 0) {
                         const can_read: usize = @intCast(@min(data_avail, out_avail));
test/standalone/http.zig
@@ -680,7 +680,7 @@ pub fn main() !void {
         for (0..total_connections) |i| {
             var req = try client.request(.GET, uri, .{ .allocator = calloc }, .{});
             req.response.parser.done = true;
-            req.connection.?.data.closing = false;
+            req.connection.?.closing = false;
             requests[i] = req;
         }