Commit a0f9a5e78d

Andrew Kelley <andrew@ziglang.org>
2025-08-16 06:52:18
std: more reliable HTTP and TLS networking
* std.Io.Reader: fix the confused semantics of rebase. Before, it was
  ambiguous whether capacity was supposed to be measured from end or
  from seek. Now it is clearly measured from seek, with an added
  assertion for clarity.
* std.crypto.tls.Client: fix a panic caused by insufficient available
  buffer size. Also avoid unnecessary rebasing.
* std.http.Reader: introduce max_head_len to limit HTTP header length.
  This prevents a crash in the underlying reader, which may require a
  minimum buffer length.
* std.http.Client: choose better buffer sizes for streams and the TLS
  client. Crucially, the buffer shared by the HTTP reader and the TLS
  client needs to be big enough for all HTTP headers *and* the maximum
  TLS record size. Bump the default HTTP header size from 4K to 8K.

Fixes #24872

I have noticed, however, that there are still fetch problems.
1 parent 07b753f
Changed files (6)
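For context on the first bullet, here is a small standalone sketch (not
part of the commit) of what "measured from seek" means for
std.Io.Reader.rebase; the state values are assumed for illustration:

const std = @import("std");

test "rebase capacity is measured from seek, not end" {
    // Assumed reader state: 10-byte buffer with unread data in buffer[3..7].
    const buffer_len: usize = 10;
    const seek: usize = 3;
    const end: usize = 7;
    // New, seek-based contract: rebase(7) is a no-op because
    // buffer[seek..] can already hold 7 bytes.
    try std.testing.expect(buffer_len - seek >= 7);
    // The old, end-based reading would have compared against the free
    // tail (buffer_len - end == 3) and moved data unnecessarily.
    try std.testing.expect(buffer_len - end < 7);
}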
lib/std/crypto/tls/Client.zig
@@ -183,7 +183,6 @@ const InitError = error{
 /// `input` is asserted to have buffer capacity at least `min_buffer_len`.
 pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
     assert(input.buffer.len >= min_buffer_len);
-    assert(output.buffer.len >= min_buffer_len);
     const host = switch (options.host) {
         .no_verification => "",
         .explicit => |host| host,
@@ -1124,12 +1123,6 @@ fn readIndirect(c: *Client) Reader.Error!usize {
         if (record_end > input.buffered().len) return 0;
     }
 
-    if (r.seek == r.end) {
-        r.seek = 0;
-        r.end = 0;
-    }
-    const cleartext_buffer = r.buffer[r.end..];
-
     const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
         inline else => |*p| switch (c.tls_version) {
             .tls_1_3 => {
@@ -1145,7 +1138,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
                     const operand: V = pad ++ mem.toBytes(big(c.read_seq));
                     break :nonce @as(V, pv.server_iv) ^ operand;
                 };
-                const cleartext = cleartext_buffer[0..ciphertext.len];
+                rebase(r, ciphertext.len);
+                const cleartext = r.buffer[r.end..][0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
                     return failRead(c, error.TlsBadRecordMac);
                 // TODO use scalar, non-slice version
@@ -1171,7 +1165,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
                 };
                 const ciphertext = input.take(message_len) catch unreachable; // already peeked
                 const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
-                const cleartext = cleartext_buffer[0..ciphertext.len];
+                rebase(r, ciphertext.len);
+                const cleartext = r.buffer[r.end..][0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
                     return failRead(c, error.TlsBadRecordMac);
                 break :cleartext .{ cleartext.len, ct };
@@ -1179,7 +1174,7 @@ fn readIndirect(c: *Client) Reader.Error!usize {
             else => unreachable,
         },
     };
-    const cleartext = cleartext_buffer[0..cleartext_len];
+    const cleartext = r.buffer[r.end..][0..cleartext_len];
     c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
     switch (inner_ct) {
         .alert => {
@@ -1275,6 +1270,15 @@ fn readIndirect(c: *Client) Reader.Error!usize {
     }
 }
 
+fn rebase(r: *Reader, capacity: usize) void {
+    if (r.buffer.len - r.end >= capacity) return;
+    const data = r.buffer[r.seek..r.end];
+    @memmove(r.buffer[0..data.len], data);
+    r.seek = 0;
+    r.end = data.len;
+    assert(r.buffer.len - r.end >= capacity);
+}
+
 fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
     c.read_err = err;
     return error.ReadFailed;
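Note that the private helper added above intentionally differs from
std.Io.Reader.rebase: it guarantees free space after end rather than
after seek, because each record's cleartext is decrypted in place into
buffer[end..]. A minimal standalone model (names assumed, not from the
commit) of that behavior:

const std = @import("std");
const assert = std.debug.assert;

/// Minimal stand-in for a reader's buffered region: unread data lives
/// in buffer[seek..end].
const State = struct { buffer: []u8, seek: usize, end: usize };

/// Same shape as the helper in the diff: ensure `capacity` free bytes
/// *after* end, since cleartext is appended there, not taken from seek.
fn rebase(r: *State, capacity: usize) void {
    if (r.buffer.len - r.end >= capacity) return;
    const data = r.buffer[r.seek..r.end];
    @memmove(r.buffer[0..data.len], data);
    r.seek = 0;
    r.end = data.len;
    assert(r.buffer.len - r.end >= capacity);
}

test rebase {
    var buf: [8]u8 = "..ABC...".*;
    var r: State = .{ .buffer = &buf, .seek = 2, .end = 5 };
    rebase(&r, 4); // only 3 free bytes after end; moves "ABC" to the front
    try std.testing.expectEqualStrings("ABC", r.buffer[r.seek..r.end]);
    try std.testing.expect(r.buffer.len - r.end >= 4);
}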
lib/std/http/Client.zig
@@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
 ///
 /// If the entire HTTP header cannot fit in this amount of bytes,
 /// `error.HttpHeadersOversize` will be returned from `Request.wait`.
-read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
+read_buffer_size: usize = 8192,
 /// Each `Connection` allocates this amount for the writer buffer.
 write_buffer_size: usize = 1024,
 
@@ -302,18 +302,22 @@ pub const Connection = struct {
             const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
             errdefer gpa.free(base);
             const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
-            const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
+            // The TLS client wants enough buffer for the max encrypted frame
+            // size, and the HTTP body reader wants enough buffer for the
+            // entire HTTP header. This means we need a combined upper bound.
+            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+            const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len];
             const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
-            const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
-            const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
-            assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
+            const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
+            const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size];
+            assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
             @memcpy(host_buffer, remote_host);
             const tls: *Tls = @ptrCast(base);
             tls.* = .{
                 .connection = .{
                     .client = client,
                     .stream_writer = stream.writer(tls_write_buffer),
-                    .stream_reader = stream.reader(tls_read_buffer),
+                    .stream_reader = stream.reader(socket_read_buffer),
                     .pool_node = .{},
                     .port = port,
                     .host_len = @intCast(remote_host.len),
@@ -329,8 +333,8 @@ pub const Connection = struct {
                         .host = .{ .explicit = remote_host },
                         .ca = .{ .bundle = client.ca_bundle },
                         .ssl_key_log = client.ssl_key_log,
-                        .read_buffer = read_buffer,
-                        .write_buffer = write_buffer,
+                        .read_buffer = tls_read_buffer,
+                        .write_buffer = socket_write_buffer,
                         // This is appropriate for HTTPS because the HTTP headers contain
                         // the content length which is used to detect truncation attacks.
                         .allow_truncation_attacks = true,
@@ -348,8 +352,9 @@ pub const Connection = struct {
         }
 
         fn allocLen(client: *Client, host_len: usize) usize {
-            return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
-                client.write_buffer_size + client.read_buffer_size;
+            const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+            return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size +
+                client.write_buffer_size + client.tls_buffer_size;
         }
 
         fn host(tls: *Tls) []u8 {
@@ -1214,6 +1219,7 @@ pub const Request = struct {
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = r.client.read_buffer_size,
         };
         r.redirect_behavior.subtractOne();
     }
@@ -1679,6 +1685,7 @@ pub fn request(
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = client.read_buffer_size,
         },
         .keep_alive = options.keep_alive,
         .method = method,
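To make the sizing concrete, a sketch of the new per-connection buffer
arithmetic. The tls_buffer_size default below is an assumption (roughly
the maximum TLS ciphertext record size); the other two values are the
defaults from this commit:

const std = @import("std");

const tls_buffer_size: usize = 16640; // assumed default, not from the diff
const read_buffer_size: usize = 8192; // new default HTTP header budget
const write_buffer_size: usize = 1024;

/// Mirrors the shape of Connection.allocLen (minus @sizeOf(Tls)): the
/// buffer shared by the TLS client and the HTTP reader must hold one
/// whole decrypted record *and* the entire HTTP head at the same time.
fn bufferLen(host_len: usize) usize {
    const tls_read_buffer_len = tls_buffer_size + read_buffer_size;
    // layout: host | tls read | tls write | socket write | socket read
    return host_len + tls_read_buffer_len + tls_buffer_size +
        write_buffer_size + tls_buffer_size;
}

test bufferLen {
    try std.testing.expectEqual(@as(usize, 59147), bufferLen("example.com".len));
}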
lib/std/http/Server.zig
@@ -29,6 +29,7 @@ pub fn init(in: *Reader, out: *Writer) Server {
             .state = .ready,
             // Populated when `http.Reader.bodyReader` is called.
             .interface = undefined,
+            .max_head_len = in.buffer.len,
         },
         .out = out,
     };
@@ -251,6 +252,7 @@ pub const Request = struct {
                 .in = undefined,
                 .state = .received_head,
                 .interface = undefined,
+                .max_head_len = 4096,
             },
             .out = undefined,
         };
lib/std/Io/Reader.zig
@@ -86,12 +86,12 @@ pub const VTable = struct {
     /// `Reader.buffer`, whichever is bigger.
     readVec: *const fn (r: *Reader, data: [][]u8) Error!usize = defaultReadVec,
 
-    /// Ensures `capacity` more data can be buffered without rebasing.
+    /// Ensures `capacity` data can be buffered without rebasing.
     ///
     /// Asserts `capacity` is within buffer capacity, or that the stream ends
     /// within `capacity` bytes.
     ///
-    /// Only called when `capacity` cannot fit into the unused capacity of
+    /// Only called when `capacity` cannot be satisfied by unused capacity of
     /// `buffer`.
     ///
     /// The default implementation moves buffered data to the start of
@@ -1037,7 +1037,7 @@ fn fillUnbuffered(r: *Reader, n: usize) Error!void {
 ///
 /// Asserts buffer capacity is at least 1.
 pub fn fillMore(r: *Reader) Error!void {
-    try rebase(r, 1);
+    try rebase(r, r.end - r.seek + 1);
     var bufs: [1][]u8 = .{""};
     _ = try r.vtable.readVec(r, &bufs);
 }
@@ -1205,24 +1205,6 @@ pub fn takeLeb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
     } }))) orelse error.Overflow;
 }
 
-pub fn expandTotalCapacity(r: *Reader, allocator: Allocator, n: usize) Allocator.Error!void {
-    if (n <= r.buffer.len) return;
-    if (r.seek > 0) rebase(r, r.buffer.len);
-    var list: ArrayList(u8) = .{
-        .items = r.buffer[0..r.end],
-        .capacity = r.buffer.len,
-    };
-    defer r.buffer = list.allocatedSlice();
-    try list.ensureTotalCapacity(allocator, n);
-}
-
-pub const FillAllocError = Error || Allocator.Error;
-
-pub fn fillAlloc(r: *Reader, allocator: Allocator, n: usize) FillAllocError!void {
-    try expandTotalCapacity(r, allocator, n);
-    return fill(r, n);
-}
-
 fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
     const result_info = @typeInfo(Result).int;
     comptime assert(result_info.bits % 7 == 0);
@@ -1253,9 +1235,9 @@ fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Resu
     }
 }
 
-/// Ensures `capacity` more data can be buffered without rebasing.
+/// Ensures `capacity` data can be buffered without rebasing.
 pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
-    if (r.end + capacity <= r.buffer.len) {
+    if (r.buffer.len - r.seek >= capacity) {
         @branchHint(.likely);
         return;
     }
@@ -1263,11 +1245,12 @@ pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
 }
 
 pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
-    if (r.end <= r.buffer.len - capacity) return;
+    assert(r.buffer.len - r.seek < capacity);
     const data = r.buffer[r.seek..r.end];
     @memmove(r.buffer[0..data.len], data);
     r.seek = 0;
     r.end = data.len;
+    assert(r.buffer.len - r.seek >= capacity);
 }
 
 test fixed {
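The fillMore change above is a consequence of the new contract: asking
for only 1 byte is now trivially satisfied from seek and could leave no
free tail for readVec. A sketch of the arithmetic, with assumed state
values:

const std = @import("std");

test "fillMore requests buffered-plus-one under the seek-based contract" {
    // Assumed state: 8-byte buffer, buffered data filling buffer[2..8].
    const buffer_len: usize = 8;
    const seek: usize = 2;
    const end: usize = 8;
    // Old call: rebase(r, 1). One byte measured from seek is always
    // available, so nothing would move despite buffer[end..] being empty.
    try std.testing.expect(buffer_len - seek >= 1);
    try std.testing.expect(buffer_len - end == 0);
    // New call: rebase(r, end - seek + 1) triggers a move whenever the
    // free tail is empty, guaranteeing readVec at least one byte.
    try std.testing.expect(buffer_len - seek < end - seek + 1);
}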
lib/std/http.zig
@@ -329,6 +329,7 @@ pub const Reader = struct {
     /// read from `in`.
     trailers: []const u8 = &.{},
     body_err: ?BodyError = null,
+    max_head_len: usize,
 
     pub const RemainingChunkLen = enum(u64) {
         head = 0,
@@ -387,10 +388,11 @@ pub const Reader = struct {
     pub fn receiveHead(reader: *Reader) HeadError![]const u8 {
         reader.trailers = &.{};
         const in = reader.in;
+        const max_head_len = reader.max_head_len;
         var hp: HeadParser = .{};
         var head_len: usize = 0;
         while (true) {
-            if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize;
+            if (head_len >= max_head_len) return error.HttpHeadersOversize;
             const remaining = in.buffered()[head_len..];
             if (remaining.len == 0) {
                 in.fillMore() catch |err| switch (err) {
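A minimal model of the new limit (simplified; the real loop also parses
the head and refills the reader): the header budget is now max_head_len
rather than whatever buffer happens to back the reader, which matters
for TLS connections whose shared buffer is much larger than the header
budget:

const std = @import("std");

fn checkHeadLen(head_len: usize, max_head_len: usize) error{HttpHeadersOversize}!void {
    if (head_len >= max_head_len) return error.HttpHeadersOversize;
}

test checkHeadLen {
    // With an 8192-byte budget, the head is rejected at 8192 bytes even
    // if the underlying (shared TLS) buffer is ~24 KiB.
    try checkHeadLen(8191, 8192);
    try std.testing.expectError(error.HttpHeadersOversize, checkHeadLen(8192, 8192));
}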
lib/std/Io.zig
@@ -717,7 +717,12 @@ pub fn Poller(comptime StreamEnum: type) type {
                 const unused = r.buffer[r.end..];
                 if (unused.len >= min_len) return unused;
             }
-            if (r.seek > 0) r.rebase(r.buffer.len) catch unreachable;
+            if (r.seek > 0) {
+                const data = r.buffer[r.seek..r.end];
+                @memmove(r.buffer[0..data.len], data);
+                r.seek = 0;
+                r.end = data.len;
+            }
             {
                 var list: std.ArrayListUnmanaged(u8) = .{
                     .items = r.buffer[0..r.end],
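This last hunk also follows from the new contract: under the default
implementation, r.rebase(r.buffer.len) now ends with an assertion that
buffer.len - end >= capacity, which a full-buffer request can only
satisfy when nothing is buffered, so the Poller inlines the move
instead. The failing arithmetic, with assumed values:

const std = @import("std");

test "rebase(buffer.len) cannot meet the new postcondition with data buffered" {
    // Assumed state: 16-byte buffer with unread data in buffer[4..10].
    const buffer_len: usize = 16;
    const seek: usize = 4;
    const end: usize = 10;
    const capacity = buffer_len; // the Poller's former request
    // defaultRebase's precondition holds, so data would be moved...
    try std.testing.expect(buffer_len - seek < capacity);
    // ...but afterwards end == 6, and the postcondition would demand
    // 16 - 6 >= 16: an assertion failure.
    const new_end = end - seek;
    try std.testing.expect(buffer_len - new_end < capacity);
}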