Commit a6ede7ba86

Jacob Young <jacobly0@users.noreply.github.com>
2024-11-05 08:24:14
std.crypto.tls: support handshake fragments
1 parent de53e6e
Changed files (1)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -274,13 +274,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
     }
 
     var tls_version: tls.ProtocolVersion = undefined;
-    // This is used for two purposes:
+    // These are used for two purposes:
     // * Detect whether a certificate is the first one presented, in which case
     //   we need to verify the host name.
+    var cert_index: usize = 0;
     // * Flip back and forth between the two cleartext buffers in order to keep
     //   the previous certificate in memory so that it can be verified by the
     //   next one.
-    var cert_index: usize = 0;
+    var cert_buf_index: usize = 0;
     var write_seq: u64 = 0;
     var read_seq: u64 = 0;
     var prev_cert: Certificate.Parsed = undefined;
@@ -315,10 +316,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
     var main_cert_pub_key: CertificatePublicKey = undefined;
     const now_sec = std.time.timestamp();
 
+    var cleartext_fragment_start: usize = 0;
+    var cleartext_fragment_end: usize = 0;
     var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
     var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
     var d: tls.Decoder = .{ .buf = &handshake_buffer };
-    while (true) {
+    fragment: while (true) {
         try d.readAtLeastOurAmt(stream, tls.record_header_len);
         const record_header = d.buf[d.idx..][0..tls.record_header_len];
         const record_ct = d.decode(tls.ContentType);
@@ -332,15 +335,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                 std.debug.assert(tls_version == .tls_1_3);
                 if (record_ct != .application_data) return error.TlsUnexpectedMessage;
                 try record_decoder.ensure(record_len);
-                const cleartext_buf = &cleartext_bufs[cert_index % 2];
-                const cleartext = cleartext: switch (handshake_cipher) {
+                const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
+                switch (handshake_cipher) {
                     inline else => |*p| {
                         const pv = &p.version.tls_1_3;
                         const P = @TypeOf(p.*).A;
                         if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow;
                         const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length);
-                        if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
-                        const cleartext = cleartext_buf[0..ciphertext.len];
+                        const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
+                        if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
+                        const cleartext = cleartext_fragment_buf[0..ciphertext.len];
                         const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
                         const nonce = if (builtin.zig_backend == .stage2_x86_64 and
                             P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1)
@@ -357,27 +361,29 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                         };
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
                             return error.TlsBadRecordMac;
-                        break :cleartext mem.trimRight(u8, cleartext, "\x00");
+                        cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len;
                     },
-                };
+                }
                 read_seq += 1;
-                const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
+                cleartext_fragment_end -= 1;
+                const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]);
                 if (ct != .handshake) return error.TlsUnexpectedMessage;
-                break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct };
+                break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
             },
             .application => {
                 std.debug.assert(tls_version == .tls_1_2);
                 if (record_ct != .handshake) return error.TlsUnexpectedMessage;
                 try record_decoder.ensure(record_len);
-                const cleartext_buf = &cleartext_bufs[cert_index % 2];
-                const cleartext = cleartext: switch (handshake_cipher) {
+                const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
+                switch (handshake_cipher) {
                     inline else => |*p| {
                         const pv = &p.version.tls_1_2;
                         const P = @TypeOf(p.*).A;
                         if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow;
                         const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
-                        if (message_len > cleartext_buf.len) return error.TlsRecordOverflow;
-                        const cleartext = cleartext_buf[0..message_len];
+                        const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
+                        if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
+                        const cleartext = cleartext_fragment_buf[0..message_len];
                         const ad = std.mem.toBytes(big(read_seq)) ++
                             record_header[0 .. 1 + 2] ++
                             std.mem.toBytes(big(message_len));
@@ -400,16 +406,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                         const ciphertext = record_decoder.slice(message_len);
                         const auth_tag = record_decoder.array(P.mac_length);
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac;
-                        break :cleartext cleartext;
+                        cleartext_fragment_end += message_len;
                     },
-                };
+                }
                 read_seq += 1;
-                break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct };
+                break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct };
             },
         };
         switch (ct) {
             .alert => {
-                try ctd.ensure(2);
+                ctd.ensure(2) catch continue :fragment;
                 const level = ctd.decode(tls.AlertLevel);
                 const desc = ctd.decode(tls.AlertDescription);
                 _ = level;
@@ -420,15 +426,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                 return error.TlsUnexpectedMessage;
             },
             .change_cipher_spec => {
-                try ctd.ensure(1);
+                ctd.ensure(1) catch continue :fragment;
                 if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter;
                 cipher_state = pending_cipher_state;
             },
             .handshake => while (true) {
-                try ctd.ensure(4);
+                ctd.ensure(4) catch continue :fragment;
                 const handshake_type = ctd.decode(tls.HandshakeType);
                 const handshake_len = ctd.decode(u24);
-                var hsd = try ctd.sub(handshake_len);
+                var hsd = ctd.sub(handshake_len) catch continue :fragment;
                 const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
                 switch (handshake_type) {
                     .server_hello => {
@@ -657,6 +663,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                             prev_cert = subject;
                             cert_index += 1;
                         }
+                        cert_buf_index += 1;
                     },
                     .server_key_exchange => {
                         if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage;
@@ -892,9 +899,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
                     else => return error.TlsUnexpectedMessage,
                 }
                 if (ctd.eof()) break;
+                cleartext_fragment_start = ctd.idx;
             },
             else => return error.TlsUnexpectedMessage,
         }
+        cleartext_fragment_start = 0;
+        cleartext_fragment_end = 0;
     }
 }