Commit 0fb78b15aa
Changed files (2)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -126,88 +126,73 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const client_hello_bytes1 = plaintext_header[5..];
var handshake_cipher: tls.HandshakeCipher = undefined;
-
- var handshake_buf: [8000]u8 = undefined;
- var len: usize = 0;
- var i: usize = i: {
- const plaintext = handshake_buf[0..5];
- len = try stream.readAtLeast(&handshake_buf, plaintext.len);
- if (len < plaintext.len) return error.EndOfStream;
- const ct = @intToEnum(tls.ContentType, plaintext[0]);
- const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
- const end = plaintext.len + frag_len;
- if (end > handshake_buf.len) return error.TlsRecordOverflow;
- if (end > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end - len);
- if (end > len) return error.EndOfStream;
- }
- const frag = handshake_buf[plaintext.len..end];
-
+ var handshake_buffer: [8000]u8 = undefined;
+ var d: tls.Decoder = .{ .buf = &handshake_buffer };
+ {
+ try d.readAtLeastOurAmt(stream, tls.record_header_len);
+ const ct = d.decode(tls.ContentType);
+ d.skip(2); // legacy_record_version
+ const record_len = d.decode(u16);
+ try d.readAtLeast(stream, record_len);
+ const server_hello_fragment = d.buf[d.idx..][0..record_len];
+ var ptd = try d.sub(record_len);
switch (ct) {
.alert => {
- const level = @intToEnum(tls.AlertLevel, frag[0]);
- const desc = @intToEnum(tls.AlertDescription, frag[1]);
+ try ptd.ensure(2);
+ const level = ptd.decode(tls.AlertLevel);
+ const desc = ptd.decode(tls.AlertDescription);
_ = level;
_ = desc;
return error.TlsAlert;
},
.handshake => {
- if (frag[0] != @enumToInt(tls.HandshakeType.server_hello)) {
+ try ptd.ensure(4);
+ const handshake_type = ptd.decode(tls.HandshakeType);
+ if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
+ const length = ptd.decode(u24);
+ var hsd = try ptd.sub(length);
+ try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
+ const legacy_version = hsd.decode(u16);
+ const random = hsd.array(32);
+ if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
+ // This is a HelloRetryRequest message. This client implementation
+ // does not expect to get one.
return error.TlsUnexpectedMessage;
}
- const length = mem.readIntBig(u24, frag[1..4]);
- if (4 + length != frag.len) return error.TlsBadLength;
- var i: usize = 4;
- const legacy_version = mem.readIntBig(u16, frag[i..][0..2]);
- i += 2;
- const random = frag[i..][0..32].*;
- i += 32;
- if (mem.eql(u8, &random, &tls.hello_retry_request_sequence)) {
- @panic("TODO handle HelloRetryRequest");
- }
- const legacy_session_id_echo_len = frag[i];
- i += 1;
+ const legacy_session_id_echo_len = hsd.decode(u8);
if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
- const legacy_session_id_echo = frag[i..][0..32];
+ const legacy_session_id_echo = hsd.array(32);
if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
return error.TlsIllegalParameter;
- i += 32;
- const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
- i += 2;
- const cipher_suite_tag = @intToEnum(tls.CipherSuite, cipher_suite_int);
- const legacy_compression_method = frag[i];
- i += 1;
- _ = legacy_compression_method;
- const extensions_size = mem.readIntBig(u16, frag[i..][0..2]);
- i += 2;
- if (i + extensions_size != frag.len) return error.TlsBadLength;
+ const cipher_suite_tag = hsd.decode(tls.CipherSuite);
+ hsd.skip(1); // legacy_compression_method
+ const extensions_size = hsd.decode(u16);
+ var all_extd = try hsd.sub(extensions_size);
var supported_version: u16 = 0;
var shared_key: [32]u8 = undefined;
var have_shared_key = false;
- while (i < frag.len) {
- const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, frag[i..][0..2]));
- i += 2;
- const ext_size = mem.readIntBig(u16, frag[i..][0..2]);
- i += 2;
- const next_i = i + ext_size;
- if (next_i > frag.len) return error.TlsBadLength;
+ while (!all_extd.eof()) {
+ try all_extd.ensure(2 + 2);
+ const et = all_extd.decode(tls.ExtensionType);
+ const ext_size = all_extd.decode(u16);
+ var extd = try all_extd.sub(ext_size);
switch (et) {
.supported_versions => {
if (supported_version != 0) return error.TlsIllegalParameter;
- supported_version = mem.readIntBig(u16, frag[i..][0..2]);
+ try extd.ensure(2);
+ supported_version = extd.decode(u16);
},
.key_share => {
if (have_shared_key) return error.TlsIllegalParameter;
have_shared_key = true;
- const named_group = @intToEnum(tls.NamedGroup, mem.readIntBig(u16, frag[i..][0..2]));
- i += 2;
- const key_size = mem.readIntBig(u16, frag[i..][0..2]);
- i += 2;
-
+ try extd.ensure(4);
+ const named_group = extd.decode(tls.NamedGroup);
+ const key_size = extd.decode(u16);
+ try extd.ensure(key_size);
switch (named_group) {
.x25519 => {
- if (key_size != 32) return error.TlsBadLength;
- const server_pub_key = frag[i..][0..32];
+ if (key_size != 32) return error.TlsIllegalParameter;
+ const server_pub_key = extd.array(32);
shared_key = crypto.dh.X25519.scalarmult(
x25519_kp.secret_key,
@@ -215,7 +200,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
) catch return error.TlsDecryptFailure;
},
.secp256r1 => {
- const server_pub_key = frag[i..][0..key_size];
+ const server_pub_key = extd.slice(key_size);
const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
const pk = PublicKey.fromSec1(server_pub_key) catch {
@@ -233,14 +218,12 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
else => {},
}
- i = next_i;
}
if (!have_shared_key) return error.TlsIllegalParameter;
+
const tls_version = if (supported_version == 0) legacy_version else supported_version;
- switch (tls_version) {
- @enumToInt(tls.ProtocolVersion.tls_1_3) => {},
- else => return error.TlsIllegalParameter,
- }
+ if (tls_version != @enumToInt(tls.ProtocolVersion.tls_1_3))
+ return error.TlsIllegalParameter;
switch (cipher_suite_tag) {
inline .AES_128_GCM_SHA256,
@@ -264,7 +247,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const p = &@field(handshake_cipher, @tagName(tag));
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
p.transcript_hash.update(host); // Client Hello part 2
- p.transcript_hash.update(frag); // Server Hello
+ p.transcript_hash.update(server_hello_fragment);
const hello_hash = p.transcript_hash.peek();
const zeroes = [1]u8{0} ** P.Hash.digest_length;
const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
@@ -289,8 +272,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
},
else => return error.TlsUnexpectedMessage,
}
- break :i end;
- };
+ }
// This is used for two purposes:
// * Detect whether a certificate is the first one presented, in which case
@@ -322,29 +304,17 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
var main_cert_pub_key_len: u16 = undefined;
while (true) {
- const end_hdr = i + 5;
- if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
- if (end_hdr > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
- if (end_hdr > len) return error.EndOfStream;
- }
- const ct = @intToEnum(tls.ContentType, handshake_buf[i]);
- i += 1;
- const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
- i += 2;
- _ = legacy_version;
- const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
- i += 2;
- const end = i + record_size;
- if (end > handshake_buf.len) return error.TlsRecordOverflow;
- if (end > len) {
- len += try stream.readAtLeast(handshake_buf[len..], end - len);
- if (end > len) return error.EndOfStream;
- }
+ try d.readAtLeastOurAmt(stream, tls.record_header_len);
+ const record_header = d.buf[d.idx..][0..5];
+ const ct = d.decode(tls.ContentType);
+ d.skip(2); // legacy_version
+ const record_len = d.decode(u16);
+ try d.readAtLeast(stream, record_len);
+ var record_decoder = try d.sub(record_len);
switch (ct) {
.change_cipher_spec => {
- if (record_size != 1) return error.TlsUnexpectedMessage;
- if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
+ try record_decoder.ensure(1);
+ if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
},
.application_data => {
const cleartext_buf = &cleartext_bufs[cert_index % 2];
@@ -352,276 +322,261 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
const cleartext = switch (handshake_cipher) {
inline else => |*p| c: {
const P = @TypeOf(p.*);
- const ciphertext_len = record_size - P.AEAD.tag_length;
- const ciphertext = handshake_buf[i..][0..ciphertext_len];
- i += ciphertext.len;
+ const ciphertext_len = record_len - P.AEAD.tag_length;
+ try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
+ const ciphertext = record_decoder.slice(ciphertext_len);
if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
const cleartext = cleartext_buf[0..ciphertext.len];
- const auth_tag = handshake_buf[i..][0..P.AEAD.tag_length].*;
+ const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
const V = @Vector(P.AEAD.nonce_length, u8);
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
read_seq += 1;
const nonce = @as(V, p.server_handshake_iv) ^ operand;
- const ad = handshake_buf[end_hdr - 5 ..][0..5];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
return error.TlsBadRecordMac;
break :c cleartext;
},
};
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
- switch (inner_ct) {
- .handshake => {
- var ct_i: usize = 0;
- while (true) {
- const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
- ct_i += 1;
- const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
- ct_i += 3;
- const next_handshake_i = ct_i + handshake_len;
- if (next_handshake_i > cleartext.len - 1)
- return error.TlsBadLength;
- const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
- const handshake = cleartext[ct_i..next_handshake_i];
- switch (handshake_type) {
- .encrypted_extensions => {
- if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
- handshake_state = .certificate;
- switch (handshake_cipher) {
- inline else => |*p| p.transcript_hash.update(wrapped_handshake),
- }
- const total_ext_size = mem.readIntBig(u16, handshake[0..2]);
- var hs_i: usize = 2;
- const end_ext_i = 2 + total_ext_size;
- while (hs_i < end_ext_i) {
- const et = @intToEnum(tls.ExtensionType, mem.readIntBig(u16, handshake[hs_i..][0..2]));
- hs_i += 2;
- const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
- hs_i += 2;
- const next_ext_i = hs_i + ext_size;
- switch (et) {
- .server_name => {},
- else => {},
- }
- hs_i = next_ext_i;
- }
- },
- .certificate => cert: {
- switch (handshake_cipher) {
- inline else => |*p| p.transcript_hash.update(wrapped_handshake),
- }
- switch (handshake_state) {
- .certificate => {},
- .trust_chain_established => break :cert,
- else => return error.TlsUnexpectedMessage,
+ if (inner_ct != .handshake) return error.TlsUnexpectedMessage;
+
+ var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
+ while (true) {
+ try ctd.ensure(4);
+ const handshake_type = ctd.decode(tls.HandshakeType);
+ const handshake_len = ctd.decode(u24);
+ var hsd = try ctd.sub(handshake_len);
+ const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
+ const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
+ switch (handshake_type) {
+ .encrypted_extensions => {
+ if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
+ handshake_state = .certificate;
+ switch (handshake_cipher) {
+ inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+ }
+ try hsd.ensure(2);
+ const total_ext_size = hsd.decode(u16);
+ var all_extd = try hsd.sub(total_ext_size);
+ while (!all_extd.eof()) {
+ try all_extd.ensure(4);
+ const et = all_extd.decode(tls.ExtensionType);
+ const ext_size = all_extd.decode(u16);
+ var extd = try all_extd.sub(ext_size);
+ _ = extd;
+ switch (et) {
+ .server_name => {},
+ else => {},
+ }
+ }
+ },
+ .certificate => cert: {
+ switch (handshake_cipher) {
+ inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+ }
+ switch (handshake_state) {
+ .certificate => {},
+ .trust_chain_established => break :cert,
+ else => return error.TlsUnexpectedMessage,
+ }
+ try hsd.ensure(1 + 4);
+ const cert_req_ctx_len = hsd.decode(u8);
+ if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
+ const certs_size = hsd.decode(u24);
+ var certs_decoder = try hsd.sub(certs_size);
+ while (!certs_decoder.eof()) {
+ try certs_decoder.ensure(3);
+ const cert_size = certs_decoder.decode(u24);
+ var certd = try certs_decoder.sub(cert_size);
+
+ const subject_cert: Certificate = .{
+ .buffer = certd.buf,
+ .index = @intCast(u32, certd.idx),
+ };
+ const subject = try subject_cert.parse();
+ if (cert_index == 0) {
+ // Verify the host on the first certificate.
+ if (!hostMatchesCommonName(host, subject.commonName())) {
+ return error.TlsCertificateHostMismatch;
}
- var hs_i: u32 = 0;
- const cert_req_ctx_len = handshake[hs_i];
- hs_i += 1;
- if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
- const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
- hs_i += 3;
- const end_certs = hs_i + certs_size;
- while (hs_i < end_certs) {
- const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
- hs_i += 3;
- const end_cert = hs_i + cert_size;
-
- const subject_cert: Certificate = .{
- .buffer = handshake,
- .index = hs_i,
- };
- const subject = try subject_cert.parse();
- if (cert_index == 0) {
- // Verify the host on the first certificate.
- if (!hostMatchesCommonName(host, subject.commonName())) {
- return error.TlsCertificateHostMismatch;
- }
- // Keep track of the public key for
- // the certificate_verify message
- // later.
- main_cert_pub_key_algo = subject.pub_key_algo;
- const pub_key = subject.pubKey();
- if (pub_key.len > main_cert_pub_key_buf.len)
- return error.CertificatePublicKeyInvalid;
- @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
- main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
- } else {
- try prev_cert.verify(subject);
- }
-
- if (ca_bundle.verify(subject)) |_| {
- handshake_state = .trust_chain_established;
- break :cert;
- } else |err| switch (err) {
- error.CertificateIssuerNotFound => {},
- else => |e| return e,
- }
-
- prev_cert = subject;
- cert_index += 1;
-
- hs_i = end_cert;
- const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
- hs_i += 2;
- hs_i += total_ext_size;
- }
- },
- .certificate_verify => {
- switch (handshake_state) {
- .trust_chain_established => handshake_state = .finished,
- .certificate => return error.TlsCertificateNotVerified,
- else => return error.TlsUnexpectedMessage,
- }
+ // Keep track of the public key for the
+ // certificate_verify message later.
+ main_cert_pub_key_algo = subject.pub_key_algo;
+ const pub_key = subject.pubKey();
+ if (pub_key.len > main_cert_pub_key_buf.len)
+ return error.CertificatePublicKeyInvalid;
+ @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
+ main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
+ } else {
+ try prev_cert.verify(subject);
+ }
- const scheme = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2]));
- const sig_len = mem.readIntBig(u16, handshake[2..4]);
- if (4 + sig_len > handshake.len) return error.TlsBadLength;
- const encoded_sig = handshake[4..][0..sig_len];
- const max_digest_len = 64;
- var verify_buffer =
- ([1]u8{0x20} ** 64) ++
- "TLS 1.3, server CertificateVerify\x00".* ++
- @as([max_digest_len]u8, undefined);
-
- const verify_bytes = switch (handshake_cipher) {
- inline else => |*p| v: {
- const transcript_digest = p.transcript_hash.peek();
- verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
- p.transcript_hash.update(wrapped_handshake);
- break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
- },
- };
- const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
-
- switch (scheme) {
- inline .ecdsa_secp256r1_sha256,
- .ecdsa_secp384r1_sha384,
- => |comptime_scheme| {
- if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
- return error.TlsBadSignatureScheme;
- const Ecdsa = SchemeEcdsa(comptime_scheme);
- const sig = try Ecdsa.Signature.fromDer(encoded_sig);
- const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
- try sig.verify(verify_bytes, key);
- },
- .rsa_pss_rsae_sha256 => {
- if (main_cert_pub_key_algo != .rsaEncryption)
- return error.TlsBadSignatureScheme;
-
- const Hash = crypto.hash.sha2.Sha256;
- const rsa = Certificate.rsa;
- const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
- const exponent = components.exponent;
- const modulus = components.modulus;
- var rsa_mem_buf: [512 * 32]u8 = undefined;
- var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
- const ally = fba.allocator();
- switch (modulus.len) {
- inline 128, 256, 512 => |modulus_len| {
- const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
- const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
- try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
- },
- else => {
- return error.TlsBadRsaSignatureBitCount;
- },
- }
+ if (ca_bundle.verify(subject)) |_| {
+ handshake_state = .trust_chain_established;
+ break :cert;
+ } else |err| switch (err) {
+ error.CertificateIssuerNotFound => {},
+ else => |e| return e,
+ }
+
+ prev_cert = subject;
+ cert_index += 1;
+
+ try certs_decoder.ensure(2);
+ const total_ext_size = certs_decoder.decode(u16);
+ var all_extd = try certs_decoder.sub(total_ext_size);
+ _ = all_extd;
+ }
+ },
+ .certificate_verify => {
+ switch (handshake_state) {
+ .trust_chain_established => handshake_state = .finished,
+ .certificate => return error.TlsCertificateNotVerified,
+ else => return error.TlsUnexpectedMessage,
+ }
+
+ try hsd.ensure(4);
+ const scheme = hsd.decode(tls.SignatureScheme);
+ const sig_len = hsd.decode(u16);
+ try hsd.ensure(sig_len);
+ const encoded_sig = hsd.slice(sig_len);
+ const max_digest_len = 64;
+ var verify_buffer =
+ ([1]u8{0x20} ** 64) ++
+ "TLS 1.3, server CertificateVerify\x00".* ++
+ @as([max_digest_len]u8, undefined);
+
+ const verify_bytes = switch (handshake_cipher) {
+ inline else => |*p| v: {
+ const transcript_digest = p.transcript_hash.peek();
+ verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
+ p.transcript_hash.update(wrapped_handshake);
+ break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
+ },
+ };
+ const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
+
+ switch (scheme) {
+ inline .ecdsa_secp256r1_sha256,
+ .ecdsa_secp384r1_sha384,
+ => |comptime_scheme| {
+ if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
+ return error.TlsBadSignatureScheme;
+ const Ecdsa = SchemeEcdsa(comptime_scheme);
+ const sig = try Ecdsa.Signature.fromDer(encoded_sig);
+ const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
+ try sig.verify(verify_bytes, key);
+ },
+ .rsa_pss_rsae_sha256 => {
+ if (main_cert_pub_key_algo != .rsaEncryption)
+ return error.TlsBadSignatureScheme;
+
+ const Hash = crypto.hash.sha2.Sha256;
+ const rsa = Certificate.rsa;
+ const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
+ const exponent = components.exponent;
+ const modulus = components.modulus;
+ var rsa_mem_buf: [512 * 32]u8 = undefined;
+ var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
+ const ally = fba.allocator();
+ switch (modulus.len) {
+ inline 128, 256, 512 => |modulus_len| {
+ const key = try rsa.PublicKey.fromBytes(exponent, modulus, ally);
+ const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
+ try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
},
else => {
- return error.TlsBadSignatureScheme;
+ return error.TlsBadRsaSignatureBitCount;
},
}
},
- .finished => {
- if (handshake_state != .finished) return error.TlsUnexpectedMessage;
- // This message is to trick buggy proxies into behaving correctly.
- const client_change_cipher_spec_msg = [_]u8{
- @enumToInt(tls.ContentType.change_cipher_spec),
- 0x03, 0x03, // legacy protocol version
- 0x00, 0x01, // length
- 0x01,
- };
- const app_cipher = switch (handshake_cipher) {
- inline else => |*p, tag| c: {
- const P = @TypeOf(p.*);
- const finished_digest = p.transcript_hash.peek();
- p.transcript_hash.update(wrapped_handshake);
- const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
- if (!mem.eql(u8, &expected_server_verify_data, handshake))
- return error.TlsDecryptError;
- const handshake_hash = p.transcript_hash.finalResult();
- const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
- const out_cleartext = [_]u8{
- @enumToInt(tls.HandshakeType.finished),
- 0, 0, verify_data.len, // length
- } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
-
- const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
-
- var finished_msg = [_]u8{
- @enumToInt(tls.ContentType.application_data),
- 0x03, 0x03, // legacy protocol version
- 0, wrapped_len, // byte length of encrypted record
- } ++ @as([wrapped_len]u8, undefined);
-
- const ad = finished_msg[0..5];
- const ciphertext = finished_msg[5..][0..out_cleartext.len];
- const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
- const nonce = p.client_handshake_iv;
- P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
-
- const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
- try stream.writeAll(&both_msgs);
-
- const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
- const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
- break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
- .client_secret = client_secret,
- .server_secret = server_secret,
- .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
- .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
- .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
- .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
- });
- },
- };
- var client: Client = .{
- .read_seq = 0,
- .write_seq = 0,
- .partial_cleartext_idx = 0,
- .partial_ciphertext_idx = 0,
- .partial_ciphertext_end = @intCast(u15, len - end),
- .received_close_notify = false,
- .application_cipher = app_cipher,
- .partially_read_buffer = undefined,
- };
- mem.copy(u8, &client.partially_read_buffer, handshake_buf[len..end]);
- return client;
- },
else => {
- return error.TlsUnexpectedMessage;
+ return error.TlsBadSignatureScheme;
},
}
- ct_i = next_handshake_i;
- if (ct_i >= cleartext.len - 1) break;
- }
- },
- else => {
- return error.TlsUnexpectedMessage;
- },
+ },
+ .finished => {
+ if (handshake_state != .finished) return error.TlsUnexpectedMessage;
+ // This message is to trick buggy proxies into behaving correctly.
+ const client_change_cipher_spec_msg = [_]u8{
+ @enumToInt(tls.ContentType.change_cipher_spec),
+ 0x03, 0x03, // legacy protocol version
+ 0x00, 0x01, // length
+ 0x01,
+ };
+ const app_cipher = switch (handshake_cipher) {
+ inline else => |*p, tag| c: {
+ const P = @TypeOf(p.*);
+ const finished_digest = p.transcript_hash.peek();
+ p.transcript_hash.update(wrapped_handshake);
+ const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
+ if (!mem.eql(u8, &expected_server_verify_data, handshake))
+ return error.TlsDecryptError;
+ const handshake_hash = p.transcript_hash.finalResult();
+ const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
+ const out_cleartext = [_]u8{
+ @enumToInt(tls.HandshakeType.finished),
+ 0, 0, verify_data.len, // length
+ } ++ verify_data ++ [1]u8{@enumToInt(tls.ContentType.handshake)};
+
+ const wrapped_len = out_cleartext.len + P.AEAD.tag_length;
+
+ var finished_msg = [_]u8{
+ @enumToInt(tls.ContentType.application_data),
+ 0x03, 0x03, // legacy protocol version
+ 0, wrapped_len, // byte length of encrypted record
+ } ++ @as([wrapped_len]u8, undefined);
+
+ const ad = finished_msg[0..5];
+ const ciphertext = finished_msg[5..][0..out_cleartext.len];
+ const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
+ const nonce = p.client_handshake_iv;
+ P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
+
+ const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
+ try stream.writeAll(&both_msgs);
+
+ const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
+ const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+ break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
+ .client_secret = client_secret,
+ .server_secret = server_secret,
+ .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
+ .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
+ .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
+ .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
+ });
+ },
+ };
+ const leftover = d.rest();
+ var client: Client = .{
+ .read_seq = 0,
+ .write_seq = 0,
+ .partial_cleartext_idx = 0,
+ .partial_ciphertext_idx = 0,
+ .partial_ciphertext_end = @intCast(u15, leftover.len),
+ .received_close_notify = false,
+ .application_cipher = app_cipher,
+ .partially_read_buffer = undefined,
+ };
+ mem.copy(u8, &client.partially_read_buffer, leftover);
+ return client;
+ },
+ else => {
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ if (ctd.eof()) break;
}
},
else => {
return error.TlsUnexpectedMessage;
},
}
- i = end;
}
-
- return error.TlsHandshakeFailure;
}
pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
@@ -638,12 +593,12 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
inline else => |*p| l: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
- const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1;
+ const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
while (true) {
const encrypted_content_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
ciphertext_buf.len -
- tls.ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
+ tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
));
if (encrypted_content_len == 0) break :l overhead_len;
@@ -829,7 +784,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// Cleartext capacity of output buffer, in records, rounded up.
const buf_cap = (cleartext_buf_len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
- const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len);
+ const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
const actual_read_len = try stream.readv(ask_iovecs);
@@ -860,13 +815,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
continue;
}
- if (in + tls.ciphertext_record_header_len > frag.len) {
+ if (in + tls.record_header_len > frag.len) {
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
const first = frag[in..];
- if (frag1.len < tls.ciphertext_record_header_len)
+ if (frag1.len < tls.record_header_len)
return finishRead2(c, first, frag1, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
@@ -875,7 +830,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
- const full_record_len = record_len + tls.ciphertext_record_header_len;
+ const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@@ -898,14 +853,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const end = in + record_len;
if (end > frag.len) {
// We need the record header on the next iteration of the loop.
- in -= tls.ciphertext_record_header_len;
+ in -= tls.record_header_len;
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
const first = frag[in..];
- const full_record_len = record_len + tls.ciphertext_record_header_len;
+ const full_record_len = record_len + tls.record_header_len;
const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
@@ -919,7 +874,12 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
}
switch (ct) {
.alert => {
- @panic("TODO handle an alert here");
+ if (in + 2 > frag.len) return error.TlsDecodeError;
+ const level = @intToEnum(tls.AlertLevel, frag[in]);
+ const desc = @intToEnum(tls.AlertDescription, frag[in + 1]);
+ _ = level;
+ _ = desc;
+ return error.TlsAlert;
},
.application_data => {
const cleartext = switch (c.application_cipher) {
lib/std/crypto/tls.zig
@@ -39,9 +39,9 @@ const assert = std.debug.assert;
pub const Client = @import("tls/Client.zig");
-pub const ciphertext_record_header_len = 5;
+pub const record_header_len = 5;
pub const max_ciphertext_len = (1 << 14) + 256;
-pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
+pub const max_ciphertext_record_len = max_ciphertext_len + record_header_len;
pub const hello_retry_request_sequence = [32]u8{
0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
@@ -360,3 +360,130 @@ pub inline fn int3(x: u24) [3]u8 {
@truncate(u8, x),
};
}
+
+/// An abstraction to ensure that protocol-parsing code does not perform an
+/// out-of-bounds read.
+pub const Decoder = struct {
+ buf: []u8,
+ /// Points to the next byte in buffer that will be decoded.
+ idx: usize = 0,
+ /// Up to this point in `buf` we have already checked that `cap` is greater than it.
+ our_end: usize = 0,
+ /// Beyond this point in `buf` is extra tag-along bytes beyond the amount we
+ /// requested with `readAtLeast`.
+ their_end: usize = 0,
+ /// Points to the end within buffer that has been filled. Beyond this point
+ /// in buf is undefined bytes.
+ cap: usize = 0,
+ /// Debug helper to prevent illegal calls to read functions.
+ disable_reads: bool = false,
+
+ pub fn fromTheirSlice(buf: []u8) Decoder {
+ return .{
+ .buf = buf,
+ .their_end = buf.len,
+ .cap = buf.len,
+ .disable_reads = true,
+ };
+ }
+
+ /// Use this function to increase `their_end`.
+ pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+ assert(!d.disable_reads);
+ const existing_amt = d.cap - d.idx;
+ d.their_end = d.idx + their_amt;
+ if (their_amt <= existing_amt) return;
+ const request_amt = their_amt - existing_amt;
+ const dest = d.buf[d.cap..];
+ if (request_amt > dest.len) return error.TlsRecordOverflow;
+ const actual_amt = try stream.readAtLeast(dest, request_amt);
+ if (actual_amt < request_amt) return error.TlsConnectionTruncated;
+ d.cap += actual_amt;
+ }
+
+ /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
+ /// Use when `our_amt` is calculated by us, not by them.
+ pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+ assert(!d.disable_reads);
+ try readAtLeast(d, stream, our_amt);
+ d.our_end = d.idx + our_amt;
+ }
+
+ /// Use this function to increase `our_end`.
+ /// This should always be called with an amount provided by us, not them.
+ pub fn ensure(d: *Decoder, amt: usize) !void {
+ d.our_end = @max(d.idx + amt, d.our_end);
+ if (d.our_end > d.their_end) return error.TlsDecodeError;
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn decode(d: *Decoder, comptime T: type) T {
+ switch (@typeInfo(T)) {
+ .Int => |info| switch (info.bits) {
+ 8 => {
+ skip(d, 1);
+ return d.buf[d.idx - 1];
+ },
+ 16 => {
+ skip(d, 2);
+ const b0: u16 = d.buf[d.idx - 2];
+ const b1: u16 = d.buf[d.idx - 1];
+ return (b0 << 8) | b1;
+ },
+ 24 => {
+ skip(d, 3);
+ const b0: u24 = d.buf[d.idx - 3];
+ const b1: u24 = d.buf[d.idx - 2];
+ const b2: u24 = d.buf[d.idx - 1];
+ return (b0 << 16) | (b1 << 8) | b2;
+ },
+ else => @compileError("unsupported int type: " ++ @typeName(T)),
+ },
+ .Enum => |info| {
+ const int = d.decode(info.tag_type);
+ if (info.is_exhaustive) @compileError("exhaustive enum cannot be used");
+ return @intToEnum(T, int);
+ },
+ else => @compileError("unsupported type: " ++ @typeName(T)),
+ }
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn array(d: *Decoder, comptime len: usize) *[len]u8 {
+ skip(d, len);
+ return d.buf[d.idx - len ..][0..len];
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn slice(d: *Decoder, len: usize) []u8 {
+ skip(d, len);
+ return d.buf[d.idx - len ..][0..len];
+ }
+
+ /// Use this function to increase `idx`.
+ pub fn skip(d: *Decoder, amt: usize) void {
+ d.idx += amt;
+ assert(d.idx <= d.our_end); // insufficient ensured bytes
+ }
+
+ pub fn eof(d: Decoder) bool {
+ assert(d.our_end <= d.their_end);
+ assert(d.idx <= d.our_end);
+ return d.idx == d.their_end;
+ }
+
+ /// Provide the length they claim, and receive a sub-decoder specific to that slice.
+ /// The parent decoder is advanced to the end.
+ pub fn sub(d: *Decoder, their_len: usize) !Decoder {
+ const end = d.idx + their_len;
+ if (end > d.their_end) return error.TlsDecodeError;
+ const sub_buf = d.buf[d.idx..end];
+ d.idx = end;
+ d.our_end = end;
+ return fromTheirSlice(sub_buf);
+ }
+
+ pub fn rest(d: Decoder) []u8 {
+ return d.buf[d.idx..d.cap];
+ }
+};