Commit 7ed7bd247e

Andrew Kelley <andrew@ziglang.org>
2022-12-21 05:30:38
std.crypto.tls: verify the common name matches
1 parent 244a97e
Changed files (2)
lib/std/crypto/tls/Client.zig
@@ -18,6 +18,7 @@ const int2 = tls.int2;
 const int3 = tls.int3;
 const array = tls.array;
 const enum_array = tls.enum_array;
+const Certificate = crypto.CertificateBundle.Certificate;
 
 application_cipher: ApplicationCipher,
 read_seq: u64,
@@ -298,6 +299,8 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
     };
 
     var read_seq: u64 = 0;
+    var validated_cert = false;
+    var is_subsequent_cert = false;
 
     while (true) {
         const end_hdr = i + 5;
@@ -386,10 +389,11 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
                                         hs_i = next_ext_i;
                                     }
                                 },
-                                @enumToInt(HandshakeType.certificate) => {
+                                @enumToInt(HandshakeType.certificate) => cert: {
                                     switch (cipher_params) {
                                         inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                                     }
+                                    if (validated_cert) break :cert;
                                     var hs_i: u32 = 0;
                                     const cert_req_ctx_len = handshake[hs_i];
                                     hs_i += 1;
@@ -402,41 +406,36 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
                                         hs_i += 3;
                                         const end_cert = hs_i + cert_size;
 
-                                        const certificate = try Der.parseElement(handshake, hs_i);
-                                        const tbs_certificate = try Der.parseElement(handshake, certificate.start);
-
-                                        const version = try Der.parseElement(handshake, tbs_certificate.start);
-                                        if (@bitCast(u8, version.identifier) != 0xa0 or
-                                            !mem.eql(u8, handshake[version.start..version.end], "\x02\x01\x02"))
-                                        {
-                                            return error.UnsupportedCertificateVersion;
+                                        const subject_cert: Certificate = .{
+                                            .buffer = handshake,
+                                            .index = hs_i,
+                                        };
+                                        const subject = try subject_cert.parse();
+                                        if (!is_subsequent_cert) {
+                                            is_subsequent_cert = true;
+                                            if (mem.eql(u8, subject.common_name, host)) {
+                                                std.debug.print("exact host match\n", .{});
+                                            } else if (mem.startsWith(u8, subject.common_name, "*.") and
+                                                mem.eql(u8, subject.common_name[2..], host))
+                                            {
+                                                std.debug.print("wildcard host match\n", .{});
+                                            } else {
+                                                std.debug.print("host does not match\n", .{});
+                                                return error.TlsCertificateInvalidHost;
+                                            }
                                         }
 
-                                        const serial_number = try Der.parseElement(handshake, version.end);
-                                        // RFC 5280, section 4.1.2.3:
-                                        // "This field MUST contain the same algorithm identifier as
-                                        // the signatureAlgorithm field in the sequence Certificate."
-                                        const signature = try Der.parseElement(handshake, serial_number.end);
-                                        const issuer_elem = try Der.parseElement(handshake, signature.end);
-
-                                        const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end];
-                                        if (ca_bundle.find(issuer_bytes)) |ca_cert_i| {
-                                            const Certificate = crypto.CertificateBundle.Certificate;
-                                            const subject: Certificate = .{
-                                                .buffer = handshake,
-                                                .index = hs_i,
-                                            };
-                                            const issuer: Certificate = .{
-                                                .buffer = ca_bundle.bytes.items,
-                                                .index = ca_cert_i,
-                                            };
-                                            if (subject.verify(issuer)) |_| {
-                                                std.debug.print("found a root CA cert matching issuer. verification success!\n", .{});
-                                            } else |err| {
-                                                std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{
-                                                    @errorName(err),
-                                                });
-                                            }
+                                        if (ca_bundle.verify(subject)) |_| {
+                                            std.debug.print("found a root CA cert matching issuer. verification success!\n", .{});
+                                            validated_cert = true;
+                                            break :cert;
+                                        } else |err| {
+                                            std.debug.print("unable to validate cert against system root CAs: {s}\n", .{
+                                                @errorName(err),
+                                            });
+                                            // TODO handle a certificate
+                                            // signing chain that ends in a
+                                            // root-validated one.
                                         }
 
                                         hs_i = end_cert;
lib/std/crypto/CertificateBundle.zig
@@ -13,6 +13,16 @@ pub const Key = struct {
     subject_end: u32,
 };
 
+pub fn verify(cb: CertificateBundle, subject: Certificate.Parsed) !void {
+    const bytes_index = cb.find(subject.issuer) orelse return error.IssuerNotFound;
+    const issuer_cert: Certificate = .{
+        .buffer = cb.bytes.items,
+        .index = bytes_index,
+    };
+    const issuer = try issuer_cert.parse();
+    try subject.verify(issuer);
+}
+
 /// The returned bytes become invalid after calling any of the rescan functions
 /// or add functions.
 pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 {
@@ -120,18 +130,11 @@ pub fn key(cb: CertificateBundle, bytes_index: u32) !Key {
     const tbs_certificate = try Der.parseElement(bytes, certificate.start);
     const version = try Der.parseElement(bytes, tbs_certificate.start);
     try checkVersion(bytes, version);
-
     const serial_number = try Der.parseElement(bytes, version.end);
-
-    // RFC 5280, section 4.1.2.3:
-    // "This field MUST contain the same algorithm identifier as
-    // the signatureAlgorithm field in the sequence Certificate."
     const signature = try Der.parseElement(bytes, serial_number.end);
     const issuer = try Der.parseElement(bytes, signature.end);
     const validity = try Der.parseElement(bytes, issuer.end);
     const subject = try Der.parseElement(bytes, validity.end);
-    //const subject_pub_key = try Der.parseElement(bytes, subject.end);
-    //const extensions = try Der.parseElement(bytes, subject_pub_key.end);
 
     return .{
         .subject_start = subject.start,
@@ -143,70 +146,163 @@ pub const Certificate = struct {
     buffer: []const u8,
     index: u32,
 
-    pub fn verify(subject: Certificate, issuer: Certificate) !void {
-        const subject_certificate = try Der.parseElement(subject.buffer, subject.index);
-        const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start);
-        const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start);
-        try checkVersion(subject.buffer, subject_version);
-        const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end);
-        // RFC 5280, section 4.1.2.3:
-        // "This field MUST contain the same algorithm identifier as
-        // the signatureAlgorithm field in the sequence Certificate."
-        const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end);
-        const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end);
-        const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end);
-        //const subject_name = try Der.parseElement(subject.buffer, subject_validity.end);
-
-        const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end);
-        const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start);
-        const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem);
-        const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end);
-        const subject_sig = try parseBitString(subject, subject_sig_elem);
-
-        const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index);
-        const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start);
-        const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start);
-        try checkVersion(issuer.buffer, issuer_version);
-        const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end);
+    pub const Algorithm = enum {
+        sha1WithRSAEncryption,
+        sha224WithRSAEncryption,
+        sha256WithRSAEncryption,
+        sha384WithRSAEncryption,
+        sha512WithRSAEncryption,
+
+        pub const map = std.ComptimeStringMap(Algorithm, .{
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption },
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption },
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption },
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption },
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption },
+        });
+
+        pub fn Hash(comptime algorithm: Algorithm) type {
+            return switch (algorithm) {
+                .sha1WithRSAEncryption => crypto.hash.Sha1,
+                .sha224WithRSAEncryption => crypto.hash.sha2.Sha224,
+                .sha256WithRSAEncryption => crypto.hash.sha2.Sha256,
+                .sha384WithRSAEncryption => crypto.hash.sha2.Sha384,
+                .sha512WithRSAEncryption => crypto.hash.sha2.Sha512,
+            };
+        }
+    };
+
+    pub const AlgorithmCategory = enum {
+        rsaEncryption,
+        X9_62_id_ecPublicKey,
+
+        pub const map = std.ComptimeStringMap(AlgorithmCategory, .{
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption },
+            .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey },
+        });
+    };
+
+    pub const Attribute = enum {
+        commonName,
+        serialNumber,
+        countryName,
+        localityName,
+        stateOrProvinceName,
+        organizationName,
+        organizationalUnitName,
+        organizationIdentifier,
+
+        pub const map = std.ComptimeStringMap(Attribute, .{
+            .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName },
+            .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber },
+            .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName },
+            .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName },
+            .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName },
+            .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName },
+            .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName },
+            .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier },
+        });
+    };
+
+    pub const Parsed = struct {
+        certificate: Certificate,
+        issuer: []const u8,
+        subject: []const u8,
+        common_name: []const u8,
+        signature: []const u8,
+        signature_algorithm: Algorithm,
+        message: []const u8,
+        pub_key_algo: AlgorithmCategory,
+        pub_key: []const u8,
+
+        pub fn verify(subject: Parsed, issuer: Parsed) !void {
+            // Check that the subject's issuer name matches the issuer's
+            // subject name.
+            if (!mem.eql(u8, subject.issuer, issuer.subject)) {
+                return error.CertificateIssuerMismatch;
+            }
+
+            // TODO check the time validity for the subject
+            // TODO check the time validity for the issuer
+
+            switch (subject.signature_algorithm) {
+                inline .sha1WithRSAEncryption,
+                .sha224WithRSAEncryption,
+                .sha256WithRSAEncryption,
+                .sha384WithRSAEncryption,
+                .sha512WithRSAEncryption,
+                => |algorithm| return verifyRsa(
+                    algorithm.Hash(),
+                    subject.message,
+                    subject.signature,
+                    issuer.pub_key_algo,
+                    issuer.pub_key,
+                ),
+            }
+        }
+    };
+
+    pub fn parse(cert: Certificate) !Parsed {
+        const cert_bytes = cert.buffer;
+        const certificate = try Der.parseElement(cert_bytes, cert.index);
+        const tbs_certificate = try Der.parseElement(cert_bytes, certificate.start);
+        const version = try Der.parseElement(cert_bytes, tbs_certificate.start);
+        try checkVersion(cert_bytes, version);
+        const serial_number = try Der.parseElement(cert_bytes, version.end);
         // RFC 5280, section 4.1.2.3:
         // "This field MUST contain the same algorithm identifier as
         // the signatureAlgorithm field in the sequence Certificate."
-        const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end);
-        const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end);
-        const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end);
-        const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end);
-        const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end);
-        const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start);
-        const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start);
-        const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem);
-        const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end);
-        const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem);
-
-        // Check that the subject's issuer name matches the issuer's subject
-        // name.
-        if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) {
-            return error.CertificateIssuerMismatch;
+        const tbs_signature = try Der.parseElement(cert_bytes, serial_number.end);
+        const issuer = try Der.parseElement(cert_bytes, tbs_signature.end);
+        const validity = try Der.parseElement(cert_bytes, issuer.end);
+        const subject = try Der.parseElement(cert_bytes, validity.end);
+
+        const pub_key_info = try Der.parseElement(cert_bytes, subject.end);
+        const pub_key_signature_algorithm = try Der.parseElement(cert_bytes, pub_key_info.start);
+        const pub_key_algo_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.start);
+        const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem);
+        const pub_key_elem = try Der.parseElement(cert_bytes, pub_key_signature_algorithm.end);
+        const pub_key = try parseBitString(cert, pub_key_elem);
+
+        const rdn = try Der.parseElement(cert_bytes, subject.start);
+        const atav = try Der.parseElement(cert_bytes, rdn.start);
+
+        var common_name: []const u8 = &.{};
+        var atav_i = atav.start;
+        while (atav_i < atav.end) {
+            const ty_elem = try Der.parseElement(cert_bytes, atav_i);
+            const ty = try parseAttribute(cert_bytes, ty_elem);
+            const val = try Der.parseElement(cert_bytes, ty_elem.end);
+            switch (ty) {
+                .commonName => common_name = cert.contents(val),
+                else => {},
+            }
+            atav_i = val.end;
         }
 
-        // TODO check the time validity for the subject
-        _ = subject_validity;
-        // TODO check the time validity for the issuer
-
-        const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end];
-        //std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo });
-        switch (subject_algo) {
-            // zig fmt: off
-              .sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1,        message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
-            .sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
-            .sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
-            .sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
-            .sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
-            // zig fmt: on
-            else => {
-                std.debug.print("unhandled algorithm: {any}\n", .{subject_algo});
-                return error.UnsupportedCertificateSignatureAlgorithm;
-            },
-        }
+        const sig_algo = try Der.parseElement(cert_bytes, tbs_certificate.end);
+        const algo_elem = try Der.parseElement(cert_bytes, sig_algo.start);
+        const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem);
+        const sig_elem = try Der.parseElement(cert_bytes, sig_algo.end);
+        const signature = try parseBitString(cert, sig_elem);
+
+        return .{
+            .certificate = cert,
+            .common_name = common_name,
+            .issuer = cert.contents(issuer),
+            .subject = cert.contents(subject),
+            .signature = signature,
+            .signature_algorithm = signature_algorithm,
+            .message = cert_bytes[certificate.start..tbs_certificate.end],
+            .pub_key_algo = pub_key_algo,
+            .pub_key = pub_key,
+        };
+    }
+
+    pub fn verify(subject: Certificate, issuer: Certificate) !void {
+        const parsed_subject = try subject.parse();
+        const parsed_issuer = try issuer.parse();
+        return parsed_subject.verify(parsed_issuer);
     }
 
     pub fn contents(cert: Certificate, elem: Der.Element) []const u8 {
@@ -219,7 +315,30 @@ pub const Certificate = struct {
         return cert.buffer[elem.start + 1 .. elem.end];
     }
 
-    fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void {
+    pub fn parseAlgorithm(bytes: []const u8, element: Der.Element) !Algorithm {
+        if (element.identifier.tag != .object_identifier)
+            return error.CertificateFieldHasWrongDataType;
+        return Algorithm.map.get(bytes[element.start..element.end]) orelse
+            return error.CertificateHasUnrecognizedAlgorithm;
+    }
+
+    pub fn parseAlgorithmCategory(bytes: []const u8, element: Der.Element) !AlgorithmCategory {
+        if (element.identifier.tag != .object_identifier)
+            return error.CertificateFieldHasWrongDataType;
+        return AlgorithmCategory.map.get(bytes[element.start..element.end]) orelse {
+            std.debug.print("unrecognized algorithm category: {}\n", .{std.fmt.fmtSliceHexLower(bytes[element.start..element.end])});
+            return error.CertificateHasUnrecognizedAlgorithmCategory;
+        };
+    }
+
+    pub fn parseAttribute(bytes: []const u8, element: Der.Element) !Attribute {
+        if (element.identifier.tag != .object_identifier)
+            return error.CertificateFieldHasWrongDataType;
+        return Attribute.map.get(bytes[element.start..element.end]) orelse
+            return error.CertificateHasUnrecognizedAlgorithm;
+    }
+
+    fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: AlgorithmCategory, pub_key: []const u8) !void {
         if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
         const pub_key_seq = try Der.parseElement(pub_key, 0);
         if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;