Commit 16f936b420

Andrew Kelley <andrew@ziglang.org>
2022-12-22 02:54:17
std.crypto.tls: handle the certificate_verify message
1 parent 29475b4
Changed files (2)
lib
lib/std/crypto/tls/Client.zig
@@ -308,8 +308,23 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
     var prev_cert: Certificate.Parsed = undefined;
     // Set to true once a trust chain has been established from the first
     // certificate to a root CA.
-    var cert_verification_done = false;
+    const HandshakeState = enum {
+        /// In this state we expect only an encrypted_extensions message.
+        encrypted_extensions,
+        /// In this state we expect certificate messages.
+        certificate,
+        /// In this state we expect certificate or certificate_verify messages.
+        /// certificate messages are ignored since the trust chain is already
+        /// established.
+        trust_chain_established,
+        /// In this state, we expect only the finished message.
+        finished,
+    };
+    var handshake_state: HandshakeState = .encrypted_extensions;
     var cleartext_bufs: [2][8000]u8 = undefined;
+    var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined;
+    var main_cert_pub_key_buf: [128]u8 = undefined;
+    var main_cert_pub_key_len: u8 = undefined;
 
     while (true) {
         const end_hdr = i + 5;
@@ -376,6 +391,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                             const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
                                 @enumToInt(HandshakeType.encrypted_extensions) => {
+                                    if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
+                                    handshake_state = .certificate;
                                     switch (handshake_cipher) {
                                         inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                                     }
@@ -403,7 +420,11 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                     switch (handshake_cipher) {
                                         inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                                     }
-                                    if (cert_verification_done) break :cert;
+                                    switch (handshake_state) {
+                                        .certificate => {},
+                                        .trust_chain_established => break :cert,
+                                        else => return error.TlsUnexpectedMessage,
+                                    }
                                     var hs_i: u32 = 0;
                                     const cert_req_ctx_len = handshake[hs_i];
                                     hs_i += 1;
@@ -421,38 +442,41 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             .index = hs_i,
                                         };
                                         const subject = try subject_cert.parse();
-                                        if (cert_index > 0) {
-                                            if (prev_cert.verify(subject)) |_| {
-                                                std.debug.print("previous certificate verified\n", .{});
-                                            } else |err| {
+                                        if (cert_index == 0) {
+                                            // Verify the host on the first certificate.
+                                            if (!hostMatchesCommonName(host, subject.commonName())) {
+                                                return error.TlsCertificateHostMismatch;
+                                            }
+
+                                            // Keep track of the public key for
+                                            // the certificate_verify message
+                                            // later.
+                                            main_cert_pub_key_algo = subject.pub_key_algo;
+                                            const pub_key = subject.pubKey();
+                                            if (pub_key.len > main_cert_pub_key_buf.len)
+                                                return error.CertificatePublicKeyInvalid;
+                                            @memcpy(&main_cert_pub_key_buf, pub_key.ptr, pub_key.len);
+                                            main_cert_pub_key_len = @intCast(@TypeOf(main_cert_pub_key_len), pub_key.len);
+                                        } else {
+                                            prev_cert.verify(subject) catch |err| {
                                                 std.debug.print("unable to validate previous cert: {s}\n", .{
                                                     @errorName(err),
                                                 });
-                                            }
-                                        } else {
-                                            // Verify the host on the first certificate.
-                                            const common_name = subject.commonName();
-                                            if (mem.eql(u8, common_name, host)) {
-                                                std.debug.print("exact host match\n", .{});
-                                            } else if (mem.startsWith(u8, common_name, "*.") and
-                                                (mem.endsWith(u8, host, common_name[1..]) or
-                                                mem.eql(u8, common_name[2..], host)))
-                                            {
-                                                std.debug.print("wildcard host match\n", .{});
-                                            } else {
-                                                std.debug.print("host does not match\n", .{});
-                                                return error.TlsCertificateInvalidHost;
-                                            }
+                                                return err;
+                                            };
                                         }
 
                                         if (ca_bundle.verify(subject)) |_| {
-                                            std.debug.print("found a root CA cert matching issuer. verification success!\n", .{});
-                                            cert_verification_done = true;
+                                            handshake_state = .trust_chain_established;
                                             break :cert;
-                                        } else |err| {
-                                            std.debug.print("unable to validate cert against system root CAs: {s}\n", .{
-                                                @errorName(err),
-                                            });
+                                        } else |err| switch (err) {
+                                            error.IssuerNotFound => {},
+                                            else => |e| {
+                                                std.debug.print("unable to validate cert against system root CAs: {s}\n", .{
+                                                    @errorName(e),
+                                                });
+                                                return e;
+                                            },
                                         }
 
                                         prev_cert = subject;
@@ -465,12 +489,46 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                     }
                                 },
                                 @enumToInt(HandshakeType.certificate_verify) => {
-                                    switch (handshake_cipher) {
-                                        inline else => |*p| p.transcript_hash.update(wrapped_handshake),
+                                    switch (handshake_state) {
+                                        .trust_chain_established => handshake_state = .finished,
+                                        .certificate => return error.TlsCertificateNotVerified,
+                                        else => return error.TlsUnexpectedMessage,
+                                    }
+
+                                    const algorithm = @intToEnum(tls.SignatureScheme, mem.readIntBig(u16, handshake[0..2]));
+                                    const sig_len = mem.readIntBig(u16, handshake[2..4]);
+                                    if (4 + sig_len > handshake.len) return error.TlsBadLength;
+                                    const encoded_sig = handshake[4..][0..sig_len];
+                                    const max_digest_len = 64;
+                                    var verify_buffer =
+                                        ([1]u8{0x20} ** 64) ++
+                                        "TLS 1.3, server CertificateVerify\x00".* ++
+                                        ([1]u8{undefined} ** max_digest_len);
+
+                                    const verify_bytes = switch (handshake_cipher) {
+                                        inline else => |*p| v: {
+                                            const transcript_digest = p.transcript_hash.peek();
+                                            verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
+                                            p.transcript_hash.update(wrapped_handshake);
+                                            break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
+                                        },
+                                    };
+                                    const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];
+
+                                    switch (algorithm) {
+                                        .ecdsa_secp256r1_sha256 => {
+                                            if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
+                                                return error.TlsBadSignatureAlgorithm;
+                                            const P256 = std.crypto.sign.ecdsa.EcdsaP256Sha256;
+                                            const sig = try P256.Signature.fromDer(encoded_sig);
+                                            const key = try P256.PublicKey.fromSec1(main_cert_pub_key);
+                                            try sig.verify(verify_bytes, key);
+                                        },
+                                        else => return error.TlsBadSignatureAlgorithm,
                                     }
-                                    std.debug.print("ignoring certificate_verify\n", .{});
                                 },
                                 @enumToInt(HandshakeType.finished) => {
+                                    if (handshake_state != .finished) return error.TlsUnexpectedMessage;
                                     // This message is to trick buggy proxies into behaving correctly.
                                     const client_change_cipher_spec_msg = [_]u8{
                                         @enumToInt(ContentType.change_cipher_spec),
@@ -762,6 +820,26 @@ fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
     return out;
 }
 
+fn hostMatchesCommonName(host: []const u8, common_name: []const u8) bool {
+    if (mem.eql(u8, common_name, host)) {
+        return true; // exact match
+    }
+
+    if (mem.startsWith(u8, common_name, "*.")) {
+        // wildcard certificate, matches any subdomain
+        if (mem.endsWith(u8, host, common_name[1..])) {
+            // The host has a subdomain, but the important part matches.
+            return true;
+        }
+        if (mem.eql(u8, common_name[2..], host)) {
+            // The host has no subdomain and matches exactly.
+            return true;
+        }
+    }
+
+    return false;
+}
+
 const builtin = @import("builtin");
 const native_endian = builtin.cpu.arch.endian();
 
lib/std/crypto/Certificate.zig
@@ -9,6 +9,10 @@ pub const Algorithm = enum {
     sha256WithRSAEncryption,
     sha384WithRSAEncryption,
     sha512WithRSAEncryption,
+    ecdsa_with_SHA224,
+    ecdsa_with_SHA256,
+    ecdsa_with_SHA384,
+    ecdsa_with_SHA512,
 
     pub const map = std.ComptimeStringMap(Algorithm, .{
         .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption },
@@ -16,15 +20,19 @@ pub const Algorithm = enum {
         .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption },
         .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption },
         .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption },
+        .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 },
+        .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 },
+        .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 },
+        .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 },
     });
 
     pub fn Hash(comptime algorithm: Algorithm) type {
         return switch (algorithm) {
             .sha1WithRSAEncryption => crypto.hash.Sha1,
-            .sha224WithRSAEncryption => crypto.hash.sha2.Sha224,
-            .sha256WithRSAEncryption => crypto.hash.sha2.Sha256,
-            .sha384WithRSAEncryption => crypto.hash.sha2.Sha384,
-            .sha512WithRSAEncryption => crypto.hash.sha2.Sha512,
+            .ecdsa_with_SHA224, .sha224WithRSAEncryption => crypto.hash.sha2.Sha224,
+            .ecdsa_with_SHA256, .sha256WithRSAEncryption => crypto.hash.sha2.Sha256,
+            .ecdsa_with_SHA384, .sha384WithRSAEncryption => crypto.hash.sha2.Sha384,
+            .ecdsa_with_SHA512, .sha512WithRSAEncryption => crypto.hash.sha2.Sha512,
         };
     }
 };
@@ -125,6 +133,13 @@ pub const Parsed = struct {
                 parsed_issuer.pub_key_algo,
                 parsed_issuer.pubKey(),
             ),
+            .ecdsa_with_SHA224,
+            .ecdsa_with_SHA256,
+            .ecdsa_with_SHA384,
+            .ecdsa_with_SHA512,
+            => {
+                return error.CertificateSignatureAlgorithmUnsupported;
+            },
         }
     }
 };
@@ -205,8 +220,11 @@ pub fn parseBitString(cert: Certificate, elem: der.Element) !der.Element.Slice {
 pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm {
     if (element.identifier.tag != .object_identifier)
         return error.CertificateFieldHasWrongDataType;
-    return Algorithm.map.get(bytes[element.slice.start..element.slice.end]) orelse
+    const oid_bytes = bytes[element.slice.start..element.slice.end];
+    return Algorithm.map.get(oid_bytes) orelse {
+        //std.debug.print("oid bytes: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)});
         return error.CertificateHasUnrecognizedAlgorithm;
+    };
 }
 
 pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory {