Commit 8136123aa7

Nameless <truemedian@gmail.com>
2023-05-27 14:40:56
std.http.Client: collapse BufferedConnection into Connection
1 parent 6c2f374
Changed files (2)
lib/std/http/Client.zig
@@ -36,21 +36,7 @@ pub const ConnectionPool = struct {
         is_tls: bool,
     };
 
-    pub const StoredConnection = struct {
-        buffered: BufferedConnection,
-        host: []u8,
-        port: u16,
-
-        proxied: bool = false,
-        closing: bool = false,
-
-        pub fn deinit(self: *StoredConnection, client: *Client) void {
-            self.buffered.close(client);
-            client.allocator.free(self.host);
-        }
-    };
-
-    const Queue = std.TailQueue(StoredConnection);
+    const Queue = std.TailQueue(Connection);
     pub const Node = Queue.Node;
 
     mutex: std.Thread.Mutex = .{},
@@ -69,7 +55,7 @@ pub const ConnectionPool = struct {
 
         var next = pool.free.last;
         while (next) |node| : (next = node.prev) {
-            if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
+            if ((node.data.protocol == .tls) != criteria.is_tls) continue;
             if (node.data.port != criteria.port) continue;
             if (!mem.eql(u8, node.data.host, criteria.host)) continue;
 
@@ -160,27 +146,25 @@ pub const ConnectionPool = struct {
 
 /// An interface to either a plain or TLS connection.
 pub const Connection = struct {
+    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
+    pub const Protocol = enum { plain, tls };
+
     stream: net.Stream,
     /// undefined unless protocol is tls.
     tls_client: *std.crypto.tls.Client,
+
     protocol: Protocol,
+    host: []u8,
+    port: u16,
 
-    pub const Protocol = enum { plain, tls };
+    proxied: bool = false,
+    closing: bool = false,
 
-    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
-        return switch (conn.protocol) {
-            .plain => conn.stream.read(buffer),
-            .tls => conn.tls_client.read(conn.stream, buffer),
-        } catch |err| switch (err) {
-            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
-            error.TlsAlert => return error.TlsAlert,
-            error.ConnectionTimedOut => return error.ConnectionTimedOut,
-            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
-            else => return error.UnexpectedReadFailure,
-        };
-    }
+    read_start: u16 = 0,
+    read_end: u16 = 0,
+    read_buf: [buffer_size]u8 = undefined,
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+    pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
         return switch (conn.protocol) {
             .plain => conn.stream.readAtLeast(buffer, len),
             .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
@@ -193,12 +177,70 @@ pub const Connection = struct {
         };
     }
 
+    pub fn fill(conn: *Connection) ReadError!void {
+        if (conn.read_end != conn.read_start) return;
+
+        const nread = try conn.conn.read(conn.read_buf[0..]);
+        if (nread == 0) return error.EndOfStream;
+        conn.read_start = 0;
+        conn.read_end = @intCast(u16, nread);
+    }
+
+    pub fn peek(conn: *Connection) []const u8 {
+        return conn.read_buf[conn.read_start..conn.read_end];
+    }
+
+    pub fn drop(conn: *Connection, num: u16) void {
+        conn.read_start += num;
+    }
+
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+        assert(len <= buffer.len);
+
+        var out_index: u16 = 0;
+        while (out_index < len) {
+            const available_read = conn.read_end - conn.read_start;
+            const available_buffer = buffer.len - out_index;
+
+            if (available_read > available_buffer) { // partially read buffered data
+                @memcpy(buffer[out_index..], conn.read_buf[conn.read_start..][0..available_buffer]);
+                out_index += available_buffer;
+                conn.read_start += available_buffer;
+
+                break;
+            } else if (available_read > 0) { // fully read buffered data
+                @memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..]);
+                out_index += available_read;
+                conn.read_start += available_read;
+
+                if (out_index >= len) break;
+            }
+
+            const leftover_buffer = available_buffer - available_read;
+            const leftover_len = len - out_index;
+
+            if (leftover_buffer > conn.read_buf.len) {
+                // skip the buffer if the output is large enough
+                return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
+            }
+
+            try conn.fill();
+        }
+
+        return out_index;
+    }
+
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return conn.readAtLeast(buffer, 1);
+    }
+
     pub const ReadError = error{
         TlsFailure,
         TlsAlert,
         ConnectionTimedOut,
         ConnectionResetByPeer,
         UnexpectedReadFailure,
+        EndOfStream,
     };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
@@ -247,111 +289,10 @@ pub const Connection = struct {
 
         conn.stream.close();
     }
-};
-
-/// A buffered (and peekable) Connection.
-pub const BufferedConnection = struct {
-    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
-
-    conn: Connection,
-    read_buf: [buffer_size]u8 = undefined,
-    read_start: u16 = 0,
-    read_end: u16 = 0,
-
-    write_buf: [buffer_size]u8 = undefined,
-    write_end: u16 = 0,
-
-    pub fn fill(bconn: *BufferedConnection) ReadError!void {
-        if (bconn.read_end != bconn.read_start) return;
-
-        const nread = try bconn.conn.read(bconn.read_buf[0..]);
-        if (nread == 0) return error.EndOfStream;
-        bconn.read_start = 0;
-        bconn.read_end = @intCast(u16, nread);
-    }
-
-    pub fn peek(bconn: *BufferedConnection) []const u8 {
-        return bconn.read_buf[bconn.read_start..bconn.read_end];
-    }
-
-    pub fn clear(bconn: *BufferedConnection, num: u16) void {
-        bconn.read_start += num;
-    }
-
-    pub fn readAtLeast(bconn: *BufferedConnection, buffer: []u8, len: usize) ReadError!usize {
-        var out_index: u16 = 0;
-        while (out_index < len) {
-            const available = bconn.read_end - bconn.read_start;
-            const left = buffer.len - out_index;
-
-            if (available > 0) {
-                const can_read = @intCast(u16, @min(available, left));
-
-                @memcpy(buffer[out_index..][0..can_read], bconn.read_buf[bconn.read_start..][0..can_read]);
-                out_index += can_read;
-                bconn.read_start += can_read;
-
-                continue;
-            }
-
-            if (left > bconn.read_buf.len) {
-                // skip the buffer if the output is large enough
-                return bconn.conn.read(buffer[out_index..]);
-            }
-
-            try bconn.fill();
-        }
-
-        return out_index;
-    }
-
-    pub fn read(bconn: *BufferedConnection, buffer: []u8) ReadError!usize {
-        return bconn.readAtLeast(buffer, 1);
-    }
 
-    pub const ReadError = Connection.ReadError || error{EndOfStream};
-    pub const Reader = std.io.Reader(*BufferedConnection, ReadError, read);
-
-    pub fn reader(bconn: *BufferedConnection) Reader {
-        return Reader{ .context = bconn };
-    }
-
-    pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
-        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
-            bconn.write_end += @intCast(u16, buffer.len);
-        } else {
-            try bconn.flush();
-            try bconn.conn.writeAll(buffer);
-        }
-    }
-
-    pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
-        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
-            bconn.write_end += @intCast(u16, buffer.len);
-
-            return buffer.len;
-        } else {
-            try bconn.flush();
-            return try bconn.conn.write(buffer);
-        }
-    }
-
-    pub fn flush(bconn: *BufferedConnection) WriteError!void {
-        defer bconn.write_end = 0;
-        return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-    }
-
-    pub const WriteError = Connection.WriteError;
-    pub const Writer = std.io.Writer(*BufferedConnection, WriteError, write);
-
-    pub fn writer(bconn: *BufferedConnection) Writer {
-        return Writer{ .context = bconn };
-    }
-
-    pub fn close(bconn: *BufferedConnection, client: *const Client) void {
-        bconn.conn.close(client);
+    pub fn deinit(conn: *Connection, client: *const Client) void {
+        conn.close(client);
+        client.allocator.free(conn.host);
     }
 };
 
@@ -585,11 +526,12 @@ pub const Request = struct {
         };
     }
 
-    pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+    pub const StartError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
 
     /// Send the request to the server.
     pub fn start(req: *Request) StartError!void {
-        const w = req.connection.data.buffered.writer();
+        var buffered = std.io.bufferedWriter(req.connection.data.writer());
+        const w = buffered.writer();
 
         try w.writeAll(@tagName(req.method));
         try w.writeByte(' ');
@@ -662,11 +604,9 @@ pub const Request = struct {
         try w.print("{}", .{req.headers});
 
         try w.writeAll("\r\n");
-
-        try req.connection.data.buffered.flush();
     }
 
-    pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
+    pub const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
 
     pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
 
@@ -679,7 +619,7 @@ pub const Request = struct {
 
         var index: usize = 0;
         while (index == 0) {
-            const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
+            const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
             if (amt == 0 and req.response.parser.done) break;
             index += amt;
         }
@@ -697,10 +637,10 @@ pub const Request = struct {
     pub fn wait(req: *Request) WaitError!void {
         while (true) { // handle redirects
             while (true) { // read headers
-                try req.connection.data.buffered.fill();
+                try req.connection.data.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
-                req.connection.data.buffered.clear(@intCast(u16, nchecked));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
+                req.connection.data.drop(@intCast(u16, nchecked));
 
                 if (req.response.parser.state.isContent()) break;
             }
@@ -816,10 +756,10 @@ pub const Request = struct {
             const has_trail = !req.response.parser.state.isContent();
 
             while (!req.response.parser.state.isContent()) { // read trailing headers
-                try req.connection.data.buffered.fill();
+                try req.connection.data.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
-                req.connection.data.buffered.clear(@intCast(u16, nchecked));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
+                req.connection.data.clear(@intCast(u16, nchecked));
             }
 
             if (has_trail) {
@@ -845,7 +785,7 @@ pub const Request = struct {
         return index;
     }
 
-    pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
+    pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };
 
     pub const Writer = std.io.Writer(*Request, WriteError, write);
 
@@ -857,16 +797,16 @@ pub const Request = struct {
     pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
         switch (req.transfer_encoding) {
             .chunked => {
-                try req.connection.data.buffered.writer().print("{x}\r\n", .{bytes.len});
-                try req.connection.data.buffered.writeAll(bytes);
-                try req.connection.data.buffered.writeAll("\r\n");
+                try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
+                try req.connection.data.writeAll(bytes);
+                try req.connection.data.writeAll("\r\n");
 
                 return bytes.len;
             },
             .content_length => |*len| {
                 if (len.* < bytes.len) return error.MessageTooLong;
 
-                const amt = try req.connection.data.buffered.write(bytes);
+                const amt = try req.connection.data.write(bytes);
                 len.* -= amt;
                 return amt;
             },
@@ -886,12 +826,10 @@ pub const Request = struct {
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
         switch (req.transfer_encoding) {
-            .chunked => try req.connection.data.buffered.writeAll("0\r\n\r\n"),
+            .chunked => try req.connection.data.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
-
-        try req.connection.data.buffered.flush();
     }
 };
 
@@ -948,11 +886,10 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
     errdefer stream.close();
 
     conn.data = .{
-        .buffered = .{ .conn = .{
-            .stream = stream,
-            .tls_client = undefined,
-            .protocol = protocol,
-        } },
+        .stream = stream,
+        .tls_client = undefined,
+        .protocol = protocol,
+
         .host = try client.allocator.dupe(u8, host),
         .port = port,
     };
@@ -961,13 +898,13 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
     switch (protocol) {
         .plain => {},
         .tls => {
-            conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
-            errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
+            conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
+            errdefer client.allocator.destroy(conn.data.tls_client);
 
-            conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
+            conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
             // This is appropriate for HTTPS because the HTTP headers contain
             // the content length which is used to detect truncation attacks.
-            conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
+            conn.data.tls_client.allow_truncation_attacks = true;
         },
     }
 
@@ -1003,7 +940,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     }
 }
 
-pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
+pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || Connection.WriteError || error{
     UnsupportedUrlScheme,
     UriMissingHost,
 
lib/std/http/protocol.zig
@@ -641,8 +641,8 @@ const MockBufferedConnection = struct {
         return bconn.buf[bconn.start..bconn.end];
     }
 
-    pub fn clear(bconn: *MockBufferedConnection, num: u16) void {
-        bconn.start += num;
+    pub fn drop(conn: *MockBufferedConnection, num: u16) void {
+        conn.start += num;
     }
 
     pub fn readAtLeast(bconn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize {
@@ -760,8 +760,8 @@ test "HeadersParser.read length" {
     while (true) { // read headers
         try bconn.fill();
 
-        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
-        bconn.clear(@intCast(u16, nchecked));
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
+        conn.drop(@intCast(u16, nchecked));
 
         if (r.state.isContent()) break;
     }
@@ -791,8 +791,8 @@ test "HeadersParser.read chunked" {
     while (true) { // read headers
         try bconn.fill();
 
-        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
-        bconn.clear(@intCast(u16, nchecked));
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
+        conn.drop(@intCast(u16, nchecked));
 
         if (r.state.isContent()) break;
     }
@@ -821,8 +821,8 @@ test "HeadersParser.read chunked trailer" {
     while (true) { // read headers
         try bconn.fill();
 
-        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
-        bconn.clear(@intCast(u16, nchecked));
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
+        conn.drop(@intCast(u16, nchecked));
 
         if (r.state.isContent()) break;
     }
@@ -836,8 +836,8 @@ test "HeadersParser.read chunked trailer" {
     while (true) { // read headers
         try bconn.fill();
 
-        const nchecked = try r.checkCompleteHead(std.testing.allocator, bconn.peek());
-        bconn.clear(@intCast(u16, nchecked));
+        const nchecked = try r.checkCompleteHead(std.testing.allocator, conn.peek());
+        conn.drop(@intCast(u16, nchecked));
 
         if (r.state.isContent()) break;
     }