Commit 8721efece4

Andrew Kelley <andrew@ziglang.org>
2025-08-08 04:47:56
std.crypto.tls.Client: always write to buffer
simplifies the logic & makes it respect limit
1 parent d7bf608
Changed files (2)
lib
std
crypto
Io
lib/std/crypto/tls/Client.zig
@@ -61,9 +61,6 @@ pub const ReadError = error{
     TlsUnexpectedMessage,
     TlsIllegalParameter,
     TlsSequenceOverflow,
-    /// The buffer provided to the read function was not at least
-    /// `min_buffer_len`.
-    OutputBufferUndersize,
 };
 
 pub const SslKeyLog = struct {
@@ -372,7 +369,8 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                         };
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
                             return error.TlsBadRecordMac;
-                        cleartext_fragment_end += std.mem.trimEnd(u8, cleartext, "\x00").len;
+                        // TODO use scalar, non-slice version
+                        cleartext_fragment_end += mem.trimEnd(u8, cleartext, "\x00").len;
                     },
                 }
                 read_seq += 1;
@@ -395,9 +393,9 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                         const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
                         if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
                         const cleartext = cleartext_fragment_buf[0..message_len];
-                        const ad = std.mem.toBytes(big(read_seq)) ++
+                        const ad = mem.toBytes(big(read_seq)) ++
                             record_header[0 .. 1 + 2] ++
-                            std.mem.toBytes(big(message_len));
+                            mem.toBytes(big(message_len));
                         const record_iv = record_decoder.array(P.record_iv_length).*;
                         const masked_read_seq = read_seq &
                             comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -738,7 +736,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                                         &.{ "server finished", &p.transcript_hash.finalResult() },
                                         P.verify_data_length,
                                     ),
-                                    .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block),
+                                    .app_cipher = mem.bytesToValue(P.Tls_1_2, &key_block),
                                 } };
                                 const pv = &p.version.tls_1_2;
                                 const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -756,7 +754,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                                         client_verify_cleartext.len ..][0..client_verify_cleartext.len],
                                     client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length],
                                     &client_verify_cleartext,
-                                    std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
+                                    mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len),
                                     nonce,
                                     pv.app_cipher.client_write_key,
                                 );
@@ -873,7 +871,10 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                             .input = input,
                             .reader = .{
                                 .buffer = options.read_buffer,
-                                .vtable = &.{ .stream = stream },
+                                .vtable = &.{
+                                    .stream = stream,
+                                    .readVec = readVec,
+                                },
                                 .seek = 0,
                                 .end = 0,
                             },
@@ -1017,7 +1018,7 @@ fn prepareCiphertextRecord(
                     const nonce = nonce: {
                         const V = @Vector(P.AEAD.nonce_length, u8);
                         const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
-                        const operand: V = pad ++ std.mem.toBytes(big(c.write_seq));
+                        const operand: V = pad ++ mem.toBytes(big(c.write_seq));
                         break :nonce @as(V, pv.client_iv) ^ operand;
                     };
                     P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
@@ -1048,7 +1049,7 @@ fn prepareCiphertextRecord(
                     record_header.* = .{@intFromEnum(inner_content_type)} ++
                         int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
                         int(u16, P.record_iv_length + message_len + P.mac_length);
-                    const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
+                    const ad = mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len);
                     const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length];
                     ciphertext_end += P.record_iv_length;
                     const nonce: [P.AEAD.nonce_length]u8 = nonce: {
@@ -1076,7 +1077,22 @@ pub fn eof(c: Client) bool {
 }
 
 fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
+    // This function writes exclusively to the buffer.
+    _ = w;
+    _ = limit;
+    const c: *Client = @alignCast(@fieldParentPtr("reader", r));
+    return readIndirect(c);
+}
+
+fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
+    // This function writes exclusively to the buffer.
+    _ = data;
     const c: *Client = @alignCast(@fieldParentPtr("reader", r));
+    return readIndirect(c);
+}
+
+fn readIndirect(c: *Client) Reader.Error!usize {
+    const r = &c.reader;
     if (c.eof()) return error.EndOfStream;
     const input = c.input;
     // If at least one full encrypted record is not buffered, read once.
@@ -1108,8 +1124,13 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
         if (record_end > input.buffered().len) return 0;
     }
 
-    var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
-    const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
+    if (r.seek == r.end) {
+        r.seek = 0;
+        r.end = 0;
+    }
+    const cleartext_buffer = r.buffer[r.end..];
+
+    const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
         inline else => |*p| switch (c.tls_version) {
             .tls_1_3 => {
                 const pv = &p.tls_1_3;
@@ -1121,23 +1142,24 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
                 const nonce = nonce: {
                     const V = @Vector(P.AEAD.nonce_length, u8);
                     const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
-                    const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
+                    const operand: V = pad ++ mem.toBytes(big(c.read_seq));
                     break :nonce @as(V, pv.server_iv) ^ operand;
                 };
-                const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+                const cleartext = cleartext_buffer[0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
                     return failRead(c, error.TlsBadRecordMac);
+                // TODO use scalar, non-slice version
                 const msg = mem.trimRight(u8, cleartext, "\x00");
-                break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+                break :cleartext .{ msg.len - 1, @enumFromInt(msg[msg.len - 1]) };
             },
             .tls_1_2 => {
                 const pv = &p.tls_1_2;
                 const P = @TypeOf(p.*);
                 const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
                 const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
-                const ad = std.mem.toBytes(big(c.read_seq)) ++
+                const ad = mem.toBytes(big(c.read_seq)) ++
                     ad_header[0 .. 1 + 2] ++
-                    std.mem.toBytes(big(message_len));
+                    mem.toBytes(big(message_len));
                 const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
                 const masked_read_seq = c.read_seq &
                     comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
@@ -1149,14 +1171,15 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
                 };
                 const ciphertext = input.take(message_len) catch unreachable; // already peeked
                 const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
-                const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+                const cleartext = cleartext_buffer[0..ciphertext.len];
                 P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
                     return failRead(c, error.TlsBadRecordMac);
-                break :cleartext .{ cleartext, ct };
+                break :cleartext .{ cleartext.len, ct };
             },
             else => unreachable,
         },
     };
+    const cleartext = cleartext_buffer[0..cleartext_len];
     c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
     switch (inner_ct) {
         .alert => {
@@ -1245,9 +1268,8 @@ fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize
             return 0;
         },
         .application_data => {
-            if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
-            try w.writeAll(cleartext);
-            return cleartext.len;
+            r.end += cleartext.len;
+            return 0;
         },
         else => return failRead(c, error.TlsUnexpectedMessage),
     }
lib/std/Io/Reader.zig
@@ -25,9 +25,7 @@ pub const VTable = struct {
     ///
     /// Returns the number of bytes written, which will be at minimum `0` and
     /// at most `limit`. The number returned, including zero, does not indicate
-    /// end of stream. `limit` is guaranteed to be at least as large as the
-    /// buffer capacity of `w`, a value whose minimum size is determined by the
-    /// stream implementation.
+    /// end of stream.
     ///
     /// The reader's internal logical seek position moves forward in accordance
     /// with the number of bytes returned from this function.