Commit 5bbedb63cf

Andrew Kelley <andrew@ziglang.org>
2022-12-28 03:10:31
std.crypto.Certificate: support verifying secp384r1 pub keys
1 parent b1cbfa0
Changed files (1)
lib
std
lib/std/crypto/Certificate.zig
@@ -71,6 +71,16 @@ pub const Attribute = enum {
     });
 };
 
+pub const NamedCurve = enum {
+    secp384r1,
+    X9_62_prime256v1,
+
+    pub const map = std.ComptimeStringMap(NamedCurve, .{
+        .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 },
+        .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 },
+    });
+};
+
 pub const Parsed = struct {
     certificate: Certificate,
     issuer_slice: Slice,
@@ -78,11 +88,16 @@ pub const Parsed = struct {
     common_name_slice: Slice,
     signature_slice: Slice,
     signature_algorithm: Algorithm,
-    pub_key_algo: AlgorithmCategory,
+    pub_key_algo: PubKeyAlgo,
     pub_key_slice: Slice,
     message_slice: Slice,
     validity: Validity,
 
+    pub const PubKeyAlgo = union(AlgorithmCategory) {
+        rsaEncryption: void,
+        X9_62_id_ecPublicKey: NamedCurve,
+    };
+
     pub const Validity = struct {
         not_before: u64,
         not_after: u64,
@@ -114,6 +129,10 @@ pub const Parsed = struct {
         return p.slice(p.pub_key_slice);
     }
 
+    pub fn pubKeySigAlgo(p: Parsed) []const u8 {
+        return p.slice(p.pub_key_signature_algorithm_slice);
+    }
+
     pub fn message(p: Parsed) []const u8 {
         return p.slice(p.message_slice);
     }
@@ -130,6 +149,7 @@ pub const Parsed = struct {
         CertificateSignatureInvalidLength,
         CertificateSignatureInvalid,
         CertificateSignatureUnsupportedBitCount,
+        CertificateSignatureNamedCurveUnsupported,
     };
 
     /// This function checks the time validity for the subject only. Checking
@@ -160,56 +180,78 @@ pub const Parsed = struct {
                 parsed_issuer.pub_key_algo,
                 parsed_issuer.pubKey(),
             ),
-            .ecdsa_with_SHA224,
+
+            inline .ecdsa_with_SHA224,
             .ecdsa_with_SHA256,
             .ecdsa_with_SHA384,
             .ecdsa_with_SHA512,
-            => {
-                return error.CertificateSignatureAlgorithmUnsupported;
-            },
+            => |algorithm| return verify_ecdsa(
+                algorithm.Hash(),
+                parsed_subject.message(),
+                parsed_subject.signature(),
+                parsed_issuer.pub_key_algo,
+                parsed_issuer.pubKey(),
+            ),
         }
     }
 };
 
 pub fn parse(cert: Certificate) !Parsed {
     const cert_bytes = cert.buffer;
-    const certificate = try der.parseElement(cert_bytes, cert.index);
-    const tbs_certificate = try der.parseElement(cert_bytes, certificate.slice.start);
-    const version = try der.parseElement(cert_bytes, tbs_certificate.slice.start);
+    const certificate = try der.Element.parse(cert_bytes, cert.index);
+    const tbs_certificate = try der.Element.parse(cert_bytes, certificate.slice.start);
+    const version = try der.Element.parse(cert_bytes, tbs_certificate.slice.start);
     try checkVersion(cert_bytes, version);
-    const serial_number = try der.parseElement(cert_bytes, version.slice.end);
+    const serial_number = try der.Element.parse(cert_bytes, version.slice.end);
     // RFC 5280, section 4.1.2.3:
     // "This field MUST contain the same algorithm identifier as
     // the signatureAlgorithm field in the sequence Certificate."
-    const tbs_signature = try der.parseElement(cert_bytes, serial_number.slice.end);
-    const issuer = try der.parseElement(cert_bytes, tbs_signature.slice.end);
-    const validity = try der.parseElement(cert_bytes, issuer.slice.end);
-    const not_before = try der.parseElement(cert_bytes, validity.slice.start);
+    const tbs_signature = try der.Element.parse(cert_bytes, serial_number.slice.end);
+    const issuer = try der.Element.parse(cert_bytes, tbs_signature.slice.end);
+    const validity = try der.Element.parse(cert_bytes, issuer.slice.end);
+    const not_before = try der.Element.parse(cert_bytes, validity.slice.start);
     const not_before_utc = try parseTime(cert, not_before);
-    const not_after = try der.parseElement(cert_bytes, not_before.slice.end);
+    const not_after = try der.Element.parse(cert_bytes, not_before.slice.end);
     const not_after_utc = try parseTime(cert, not_after);
-    const subject = try der.parseElement(cert_bytes, validity.slice.end);
-
-    const pub_key_info = try der.parseElement(cert_bytes, subject.slice.end);
-    const pub_key_signature_algorithm = try der.parseElement(cert_bytes, pub_key_info.slice.start);
-    const pub_key_algo_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.start);
-    const pub_key_algo = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem);
-    const pub_key_elem = try der.parseElement(cert_bytes, pub_key_signature_algorithm.slice.end);
+    const subject = try der.Element.parse(cert_bytes, validity.slice.end);
+
+    const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end);
+    const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start);
+    const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start);
+    const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem);
+    var pub_key_algo: Parsed.PubKeyAlgo = undefined;
+    switch (pub_key_algo_tag) {
+        .rsaEncryption => {
+            pub_key_algo = .{ .rsaEncryption = {} };
+        },
+        .X9_62_id_ecPublicKey => {
+            // RFC 5480 Section 2.1.1.1 Named Curve
+            // ECParameters ::= CHOICE {
+            //   namedCurve         OBJECT IDENTIFIER
+            //   -- implicitCurve   NULL
+            //   -- specifiedCurve  SpecifiedECDomain
+            // }
+            const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end);
+            const named_curve = try parseNamedCurve(cert_bytes, params_elem);
+            pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve };
+        },
+    }
+    const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end);
     const pub_key = try parseBitString(cert, pub_key_elem);
 
     var common_name = der.Element.Slice.empty;
     var name_i = subject.slice.start;
     //std.debug.print("subject name:\n", .{});
     while (name_i < subject.slice.end) {
-        const rdn = try der.parseElement(cert_bytes, name_i);
+        const rdn = try der.Element.parse(cert_bytes, name_i);
         var rdn_i = rdn.slice.start;
         while (rdn_i < rdn.slice.end) {
-            const atav = try der.parseElement(cert_bytes, rdn_i);
+            const atav = try der.Element.parse(cert_bytes, rdn_i);
             var atav_i = atav.slice.start;
             while (atav_i < atav.slice.end) {
-                const ty_elem = try der.parseElement(cert_bytes, atav_i);
+                const ty_elem = try der.Element.parse(cert_bytes, atav_i);
                 const ty = try parseAttribute(cert_bytes, ty_elem);
-                const val = try der.parseElement(cert_bytes, ty_elem.slice.end);
+                const val = try der.Element.parse(cert_bytes, ty_elem.slice.end);
                 //std.debug.print(" {s}: '{s}'\n", .{
                 //    @tagName(ty), cert_bytes[val.slice.start..val.slice.end],
                 //});
@@ -224,10 +266,10 @@ pub fn parse(cert: Certificate) !Parsed {
         name_i = rdn.slice.end;
     }
 
-    const sig_algo = try der.parseElement(cert_bytes, tbs_certificate.slice.end);
-    const algo_elem = try der.parseElement(cert_bytes, sig_algo.slice.start);
+    const sig_algo = try der.Element.parse(cert_bytes, tbs_certificate.slice.end);
+    const algo_elem = try der.Element.parse(cert_bytes, sig_algo.slice.start);
     const signature_algorithm = try parseAlgorithm(cert_bytes, algo_elem);
-    const sig_elem = try der.parseElement(cert_bytes, sig_algo.slice.end);
+    const sig_elem = try der.Element.parse(cert_bytes, sig_algo.slice.end);
     const signature = try parseBitString(cert, sig_elem);
 
     return .{
@@ -391,45 +433,52 @@ test parseYear4 {
 }
 
 pub fn parseAlgorithm(bytes: []const u8, element: der.Element) !Algorithm {
-    if (element.identifier.tag != .object_identifier)
-        return error.CertificateFieldHasWrongDataType;
-    const oid_bytes = bytes[element.slice.start..element.slice.end];
-    return Algorithm.map.get(oid_bytes) orelse {
-        //std.debug.print("oid bytes: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)});
-        return error.CertificateHasUnrecognizedAlgorithm;
-    };
+    return parseEnum(Algorithm, bytes, element);
 }
 
 pub fn parseAlgorithmCategory(bytes: []const u8, element: der.Element) !AlgorithmCategory {
-    if (element.identifier.tag != .object_identifier)
-        return error.CertificateFieldHasWrongDataType;
-    return AlgorithmCategory.map.get(bytes[element.slice.start..element.slice.end]) orelse
-        return error.CertificateHasUnrecognizedAlgorithmCategory;
+    return parseEnum(AlgorithmCategory, bytes, element);
 }
 
 pub fn parseAttribute(bytes: []const u8, element: der.Element) !Attribute {
+    return parseEnum(Attribute, bytes, element);
+}
+
+pub fn parseNamedCurve(bytes: []const u8, element: der.Element) !NamedCurve {
+    return parseEnum(NamedCurve, bytes, element);
+}
+
+fn parseEnum(comptime E: type, bytes: []const u8, element: der.Element) !E {
     if (element.identifier.tag != .object_identifier)
         return error.CertificateFieldHasWrongDataType;
     const oid_bytes = bytes[element.slice.start..element.slice.end];
-    return Attribute.map.get(oid_bytes) orelse {
-        //std.debug.print("attr: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)});
-        return error.CertificateHasUnrecognizedAttribute;
+    return E.map.get(oid_bytes) orelse {
+        //std.debug.print("tag: {}\n", .{std.fmt.fmtSliceHexLower(oid_bytes)});
+        return error.CertificateHasUnrecognizedObjectId;
     };
 }
 
+pub fn checkVersion(bytes: []const u8, version: der.Element) !void {
+    if (@bitCast(u8, version.identifier) != 0xa0 or
+        !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02"))
+    {
+        return error.UnsupportedCertificateVersion;
+    }
+}
+
 fn verifyRsa(
     comptime Hash: type,
     message: []const u8,
     sig: []const u8,
-    pub_key_algo: AlgorithmCategory,
+    pub_key_algo: Parsed.PubKeyAlgo,
     pub_key: []const u8,
 ) !void {
     if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
-    const pub_key_seq = try der.parseElement(pub_key, 0);
+    const pub_key_seq = try der.Element.parse(pub_key, 0);
     if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
-    const modulus_elem = try der.parseElement(pub_key, pub_key_seq.slice.start);
+    const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
     if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
-    const exponent_elem = try der.parseElement(pub_key, modulus_elem.slice.end);
+    const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
     if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
     // Skip over meaningless zeroes in the modulus.
     const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
@@ -504,11 +553,39 @@ fn verifyRsa(
     }
 }
 
-pub fn checkVersion(bytes: []const u8, version: der.Element) !void {
-    if (@bitCast(u8, version.identifier) != 0xa0 or
-        !mem.eql(u8, bytes[version.slice.start..version.slice.end], "\x02\x01\x02"))
-    {
-        return error.UnsupportedCertificateVersion;
+fn verify_ecdsa(
+    comptime Hash: type,
+    message: []const u8,
+    encoded_sig: []const u8,
+    pub_key_algo: Parsed.PubKeyAlgo,
+    sec1_pub_key: []const u8,
+) !void {
+    const sig_named_curve = switch (pub_key_algo) {
+        .X9_62_id_ecPublicKey => |named_curve| named_curve,
+        else => return error.CertificateSignatureAlgorithmMismatch,
+    };
+
+    switch (sig_named_curve) {
+        .secp384r1 => {
+            const P = crypto.ecc.P384;
+            const Ecdsa = crypto.sign.ecdsa.Ecdsa(P, Hash);
+            const sig = Ecdsa.Signature.fromDer(encoded_sig) catch |err| switch (err) {
+                error.InvalidEncoding => return error.CertificateSignatureInvalid,
+            };
+            const pub_key = Ecdsa.PublicKey.fromSec1(sec1_pub_key) catch |err| switch (err) {
+                error.InvalidEncoding => return error.CertificateSignatureInvalid,
+                error.NonCanonical => return error.CertificateSignatureInvalid,
+                error.NotSquare => return error.CertificateSignatureInvalid,
+            };
+            sig.verify(message, pub_key) catch |err| switch (err) {
+                error.IdentityElement => return error.CertificateSignatureInvalid,
+                error.NonCanonical => return error.CertificateSignatureInvalid,
+                error.SignatureVerificationFailed => return error.CertificateSignatureInvalid,
+            };
+        },
+        .X9_62_prime256v1 => {
+            return error.CertificateSignatureNamedCurveUnsupported;
+        },
     }
 }
 
@@ -559,45 +636,45 @@ pub const der = struct {
 
             pub const empty: Slice = .{ .start = 0, .end = 0 };
         };
-    };
 
-    pub const ParseElementError = error{CertificateFieldHasInvalidLength};
+        pub const ParseError = error{CertificateFieldHasInvalidLength};
+
+        pub fn parse(bytes: []const u8, index: u32) ParseError!Element {
+            var i = index;
+            const identifier = @bitCast(Identifier, bytes[i]);
+            i += 1;
+            const size_byte = bytes[i];
+            i += 1;
+            if ((size_byte >> 7) == 0) {
+                return .{
+                    .identifier = identifier,
+                    .slice = .{
+                        .start = i,
+                        .end = i + size_byte,
+                    },
+                };
+            }
+
+            const len_size = @truncate(u7, size_byte);
+            if (len_size > @sizeOf(u32)) {
+                return error.CertificateFieldHasInvalidLength;
+            }
+
+            const end_i = i + len_size;
+            var long_form_size: u32 = 0;
+            while (i < end_i) : (i += 1) {
+                long_form_size = (long_form_size << 8) | bytes[i];
+            }
 
-    pub fn parseElement(bytes: []const u8, index: u32) ParseElementError!Element {
-        var i = index;
-        const identifier = @bitCast(Identifier, bytes[i]);
-        i += 1;
-        const size_byte = bytes[i];
-        i += 1;
-        if ((size_byte >> 7) == 0) {
             return .{
                 .identifier = identifier,
                 .slice = .{
                     .start = i,
-                    .end = i + size_byte,
+                    .end = i + long_form_size,
                 },
             };
         }
-
-        const len_size = @truncate(u7, size_byte);
-        if (len_size > @sizeOf(u32)) {
-            return error.CertificateFieldHasInvalidLength;
-        }
-
-        const end_i = i + len_size;
-        var long_form_size: u32 = 0;
-        while (i < end_i) : (i += 1) {
-            long_form_size = (long_form_size << 8) | bytes[i];
-        }
-
-        return .{
-            .identifier = identifier,
-            .slice = .{
-                .start = i,
-                .end = i + long_form_size,
-            },
-        };
-    }
+    };
 };
 
 test {