Commit a6ede7ba86
Changed files (1)
lib
std
crypto
tls
lib/std/crypto/tls/Client.zig
@@ -274,13 +274,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
}
var tls_version: tls.ProtocolVersion = undefined;
- // This is used for two purposes:
+ // These are used for two purposes:
// * Detect whether a certificate is the first one presented, in which case
// we need to verify the host name.
+ var cert_index: usize = 0;
// * Flip back and forth between the two cleartext buffers in order to keep
// the previous certificate in memory so that it can be verified by the
// next one.
- var cert_index: usize = 0;
+ var cert_buf_index: usize = 0;
var write_seq: u64 = 0;
var read_seq: u64 = 0;
var prev_cert: Certificate.Parsed = undefined;
@@ -315,10 +316,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
var main_cert_pub_key: CertificatePublicKey = undefined;
const now_sec = std.time.timestamp();
+ var cleartext_fragment_start: usize = 0;
+ var cleartext_fragment_end: usize = 0;
var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
var d: tls.Decoder = .{ .buf = &handshake_buffer };
- while (true) {
+ fragment: while (true) {
try d.readAtLeastOurAmt(stream, tls.record_header_len);
const record_header = d.buf[d.idx..][0..tls.record_header_len];
const record_ct = d.decode(tls.ContentType);
@@ -332,15 +335,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
std.debug.assert(tls_version == .tls_1_3);
if (record_ct != .application_data) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
- const cleartext_buf = &cleartext_bufs[cert_index % 2];
- const cleartext = cleartext: switch (handshake_cipher) {
+ const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
+ switch (handshake_cipher) {
inline else => |*p| {
const pv = &p.version.tls_1_3;
const P = @TypeOf(p.*).A;
if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow;
const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length);
- if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
- const cleartext = cleartext_buf[0..ciphertext.len];
+ const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
+ if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_fragment_buf[0..ciphertext.len];
const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
const nonce = if (builtin.zig_backend == .stage2_x86_64 and
P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1)
@@ -357,27 +361,29 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
};
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch
return error.TlsBadRecordMac;
- break :cleartext mem.trimRight(u8, cleartext, "\x00");
+ cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len;
},
- };
+ }
read_seq += 1;
- const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
+ cleartext_fragment_end -= 1;
+ const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]);
if (ct != .handshake) return error.TlsUnexpectedMessage;
- break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct };
+ break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
},
.application => {
std.debug.assert(tls_version == .tls_1_2);
if (record_ct != .handshake) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
- const cleartext_buf = &cleartext_bufs[cert_index % 2];
- const cleartext = cleartext: switch (handshake_cipher) {
+ const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
+ switch (handshake_cipher) {
inline else => |*p| {
const pv = &p.version.tls_1_2;
const P = @TypeOf(p.*).A;
if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow;
const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
- if (message_len > cleartext_buf.len) return error.TlsRecordOverflow;
- const cleartext = cleartext_buf[0..message_len];
+ const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..];
+ if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_fragment_buf[0..message_len];
const ad = std.mem.toBytes(big(read_seq)) ++
record_header[0 .. 1 + 2] ++
std.mem.toBytes(big(message_len));
@@ -400,16 +406,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const ciphertext = record_decoder.slice(message_len);
const auth_tag = record_decoder.array(P.mac_length);
P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac;
- break :cleartext cleartext;
+ cleartext_fragment_end += message_len;
},
- };
+ }
read_seq += 1;
- break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct };
+ break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct };
},
};
switch (ct) {
.alert => {
- try ctd.ensure(2);
+ ctd.ensure(2) catch continue :fragment;
const level = ctd.decode(tls.AlertLevel);
const desc = ctd.decode(tls.AlertDescription);
_ = level;
@@ -420,15 +426,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
return error.TlsUnexpectedMessage;
},
.change_cipher_spec => {
- try ctd.ensure(1);
+ ctd.ensure(1) catch continue :fragment;
if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter;
cipher_state = pending_cipher_state;
},
.handshake => while (true) {
- try ctd.ensure(4);
+ ctd.ensure(4) catch continue :fragment;
const handshake_type = ctd.decode(tls.HandshakeType);
const handshake_len = ctd.decode(u24);
- var hsd = try ctd.sub(handshake_len);
+ var hsd = ctd.sub(handshake_len) catch continue :fragment;
const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
switch (handshake_type) {
.server_hello => {
@@ -657,6 +663,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
prev_cert = subject;
cert_index += 1;
}
+ cert_buf_index += 1;
},
.server_key_exchange => {
if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage;
@@ -892,9 +899,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
else => return error.TlsUnexpectedMessage,
}
if (ctd.eof()) break;
+ cleartext_fragment_start = ctd.idx;
},
else => return error.TlsUnexpectedMessage,
}
+ cleartext_fragment_start = 0;
+ cleartext_fragment_end = 0;
}
}