Commit 5d7eca6669

Andrew Kelley <andrew@ziglang.org>
2022-12-19 05:32:15
std.crypto.tls.Client: fix verify_data for batched handshakes
1 parent e2c16d0
Changed files (2)
lib
std
lib/std/crypto/tls/Client.zig
@@ -62,12 +62,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         .rsa_pss_rsae_sha384,
         .rsa_pss_rsae_sha512,
         .ed25519,
-        .ed448,
-        .rsa_pss_pss_sha256,
-        .rsa_pss_pss_sha384,
-        .rsa_pss_pss_sha512,
-        .rsa_pkcs1_sha1,
-        .ecdsa_sha1,
     })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
         .secp256r1,
         .x25519,
@@ -98,24 +92,21 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         int2(legacy_compression_methods) ++
         extensions_header;
 
-    const handshake =
+    const out_handshake =
         [_]u8{@enumToInt(HandshakeType.client_hello)} ++
         int3(@intCast(u24, client_hello.len + host_len)) ++
         client_hello;
 
-    const hello_header = [_]u8{
-        // Plaintext header
+    const plaintext_header = [_]u8{
         @enumToInt(ContentType.handshake),
         0x03, 0x01, // legacy_record_version
-    } ++
-        int2(@intCast(u16, handshake.len + host_len)) ++
-        handshake;
+    } ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
 
     {
         var iovecs = [_]std.os.iovec_const{
             .{
-                .iov_base = &hello_header,
-                .iov_len = hello_header.len,
+                .iov_base = &plaintext_header,
+                .iov_len = plaintext_header.len,
             },
             .{
                 .iov_base = host.ptr,
@@ -125,7 +116,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         try stream.writevAll(&iovecs);
     }
 
-    const client_hello_bytes1 = hello_header[5..];
+    const client_hello_bytes1 = plaintext_header[5..];
 
     var cipher_params: CipherParams = undefined;
 
@@ -176,7 +167,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                 const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
                 i += 2;
                 const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int);
-                std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag});
                 const legacy_compression_method = frag[i];
                 i += 1;
                 _ = legacy_compression_method;
@@ -243,12 +233,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                 if (!have_shared_key) return error.TlsIllegalParameter;
                 const tls_version = if (supported_version == 0) legacy_version else supported_version;
                 switch (tls_version) {
-                    @enumToInt(tls.ProtocolVersion.tls_1_2) => {
-                        std.debug.print("server wants TLS v1.2\n", .{});
-                    },
-                    @enumToInt(tls.ProtocolVersion.tls_1_3) => {
-                        std.debug.print("server wants TLS v1.3\n", .{});
-                    },
+                    @enumToInt(tls.ProtocolVersion.tls_1_3) => {},
                     else => return error.TlsIllegalParameter,
                 }
 
@@ -270,7 +255,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                             .client_handshake_iv = undefined,
                             .server_handshake_iv = undefined,
                             .transcript_hash = P.Hash.init(.{}),
-                            .finished_digest = undefined,
                         });
                         const p = &@field(cipher_params, @tagName(tag));
                         p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
@@ -361,7 +345,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                         const ad = handshake_buf[end_hdr - 5 ..][0..5];
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
                             return error.TlsBadRecordMac;
-                        p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]);
                         break :c cleartext;
                     },
                 };
@@ -378,17 +361,22 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                             const next_handshake_i = ct_i + handshake_len;
                             if (next_handshake_i > cleartext.len - 1)
                                 return error.TlsBadLength;
+                            const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
+                            const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
                                 @enumToInt(HandshakeType.encrypted_extensions) => {
-                                    const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
-                                    ct_i += 2;
-                                    const end_ext_i = ct_i + total_ext_size;
-                                    while (ct_i < end_ext_i) {
-                                        const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
-                                        ct_i += 2;
-                                        const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
-                                        ct_i += 2;
-                                        const next_ext_i = ct_i + ext_size;
+                                    switch (cipher_params) {
+                                        inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+                                    }
+                                    const total_ext_size = mem.readIntBig(u16, handshake[0..2]);
+                                    var hs_i: usize = 2;
+                                    const end_ext_i = 2 + total_ext_size;
+                                    while (hs_i < end_ext_i) {
+                                        const et = mem.readIntBig(u16, handshake[hs_i..][0..2]);
+                                        hs_i += 2;
+                                        const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
+                                        hs_i += 2;
+                                        const next_ext_i = hs_i + ext_size;
                                         switch (et) {
                                             @enumToInt(tls.ExtensionType.server_name) => {},
                                             else => {
@@ -397,19 +385,38 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                                                 });
                                             },
                                         }
-                                        ct_i = next_ext_i;
+                                        hs_i = next_ext_i;
                                     }
                                 },
                                 @enumToInt(HandshakeType.certificate) => {
-                                    std.debug.print("cool certificate bro\n", .{});
+                                    switch (cipher_params) {
+                                        inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+                                    }
+                                    var hs_i: usize = 0;
+                                    const cert_req_ctx_len = handshake[hs_i];
+                                    hs_i += 1;
+                                    if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
+                                    const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
+                                    hs_i += 3;
+                                    const end_certs = hs_i + certs_size;
+                                    while (hs_i < end_certs) {
+                                        const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
+                                        hs_i += 3;
+                                        hs_i += cert_size;
+                                        const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
+                                        hs_i += 2;
+                                        hs_i += total_ext_size;
+
+                                        std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions\n", .{
+                                            cert_size, total_ext_size,
+                                        });
+                                    }
                                 },
                                 @enumToInt(HandshakeType.certificate_verify) => {
-                                    std.debug.print("the certificate came with a fancy signature\n", .{});
                                     switch (cipher_params) {
-                                        inline else => |*p| {
-                                            p.finished_digest = p.transcript_hash.peek();
-                                        },
+                                        inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                                     }
+                                    std.debug.print("ignoring certificate_verify\n", .{});
                                 },
                                 @enumToInt(HandshakeType.finished) => {
                                     // This message is to trick buggy proxies into behaving correctly.
@@ -422,9 +429,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                                     const app_cipher = switch (cipher_params) {
                                         inline else => |*p, tag| c: {
                                             const P = @TypeOf(p.*);
-                                            const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key);
-                                            const actual_server_verify_data = cleartext[ct_i..][0..handshake_len];
-                                            if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data))
+                                            const finished_digest = p.transcript_hash.peek();
+                                            p.transcript_hash.update(wrapped_handshake);
+                                            const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
+                                            if (!mem.eql(u8, &expected_server_verify_data, handshake))
                                                 return error.TlsDecryptError;
                                             const handshake_hash = p.transcript_hash.finalResult();
                                             const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
lib/std/crypto/tls.zig
@@ -221,6 +221,12 @@ pub const CipherSuite = enum(u16) {
     _,
 };
 
+pub const CertificateType = enum(u8) {
+    X509 = 0,
+    RawPublicKey = 2,
+    _,
+};
+
 pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type {
     return struct {
         pub const AEAD = AeadType;
@@ -237,7 +243,6 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type {
         client_handshake_iv: [AEAD.nonce_length]u8,
         server_handshake_iv: [AEAD.nonce_length]u8,
         transcript_hash: Hash,
-        finished_digest: [Hash.digest_length]u8,
     };
 }