Commit 244a97e8ad

Andrew Kelley <andrew@ziglang.org>
2022-12-21 03:26:23
std.crypto.tls: certificate signature validation
1 parent 504070e
Changed files (2)
lib/std/crypto/tls/Client.zig
@@ -53,15 +53,12 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
         0x02, // byte length of supported versions
         0x03, 0x04, // TLS 1.3
     }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
-        .rsa_pkcs1_sha256,
-        .rsa_pkcs1_sha384,
-        .rsa_pkcs1_sha512,
         .ecdsa_secp256r1_sha256,
         .ecdsa_secp384r1_sha384,
         .ecdsa_secp521r1_sha512,
-        .rsa_pss_rsae_sha256,
-        .rsa_pss_rsae_sha384,
-        .rsa_pss_rsae_sha512,
+        .rsa_pkcs1_sha256,
+        .rsa_pkcs1_sha384,
+        .rsa_pkcs1_sha512,
         .ed25519,
     })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
         .secp256r1,
@@ -420,33 +417,32 @@ pub fn init(stream: net.Stream, ca_bundle: crypto.CertificateBundle, host: []con
                                         // "This field MUST contain the same algorithm identifier as
                                         // the signatureAlgorithm field in the sequence Certificate."
                                         const signature = try Der.parseElement(handshake, serial_number.end);
-                                        const issuer = try Der.parseElement(handshake, signature.end);
-                                        const validity = try Der.parseElement(handshake, issuer.end);
-                                        const subject = try Der.parseElement(handshake, validity.end);
-                                        const subject_pub_key = try Der.parseElement(handshake, subject.end);
-                                        const extensions = try Der.parseElement(handshake, subject_pub_key.end);
-                                        _ = extensions;
-
-                                        const signature_algorithm = try Der.parseElement(handshake, tbs_certificate.end);
-                                        const signature_value = try Der.parseElement(handshake, signature_algorithm.end);
-                                        _ = signature_value;
-
-                                        const algorithm_elem = try Der.parseElement(handshake, signature_algorithm.start);
-                                        const algorithm = try Der.parseObjectId(handshake, algorithm_elem);
-                                        std.debug.print("cert has this signature algorithm: {any}\n", .{algorithm});
-                                        //const parameters = try Der.parseElement(signature_algorithm.contents, &sa_i);
+                                        const issuer_elem = try Der.parseElement(handshake, signature.end);
+
+                                        const issuer_bytes = handshake[issuer_elem.start..issuer_elem.end];
+                                        if (ca_bundle.find(issuer_bytes)) |ca_cert_i| {
+                                            const Certificate = crypto.CertificateBundle.Certificate;
+                                            const subject: Certificate = .{
+                                                .buffer = handshake,
+                                                .index = hs_i,
+                                            };
+                                            const issuer: Certificate = .{
+                                                .buffer = ca_bundle.bytes.items,
+                                                .index = ca_cert_i,
+                                            };
+                                            if (subject.verify(issuer)) |_| {
+                                                std.debug.print("found a root CA cert matching issuer. verification success!\n", .{});
+                                            } else |err| {
+                                                std.debug.print("found a root CA cert matching issuer. verification failure: {s}\n", .{
+                                                    @errorName(err),
+                                                });
+                                            }
+                                        }
 
                                         hs_i = end_cert;
                                         const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
                                         hs_i += 2;
                                         hs_i += total_ext_size;
-
-                                        const issuer_bytes = handshake[issuer.start..issuer.end];
-                                        const ca_cert = ca_bundle.find(issuer_bytes);
-
-                                        std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions. ca_found={any}\n", .{
-                                            cert_size, total_ext_size, ca_cert != null,
-                                        });
                                     }
                                 },
                                 @enumToInt(HandshakeType.certificate_verify) => {
lib/std/crypto/CertificateBundle.zig
@@ -15,7 +15,7 @@ pub const Key = struct {
 
 /// The returned bytes become invalid after calling any of the rescan functions
 /// or add functions.
-pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 {
+pub fn find(cb: CertificateBundle, subject_name: []const u8) ?u32 {
     const Adapter = struct {
         cb: CertificateBundle,
 
@@ -29,8 +29,7 @@ pub fn find(cb: CertificateBundle, subject_name: []const u8) ?[]const u8 {
             return mem.eql(u8, a, b);
         }
     };
-    const index = cb.map.getAdapted(subject_name, Adapter{ .cb = cb }) orelse return null;
-    return cb.bytes.items[index..];
+    return cb.map.getAdapted(subject_name, Adapter{ .cb = cb });
 }
 
 pub fn deinit(cb: *CertificateBundle, gpa: Allocator) void {
@@ -105,7 +104,7 @@ pub fn addCertsFromFile(
         const decoded_start = @intCast(u32, cb.bytes.items.len);
         const dest_buf = cb.bytes.allocatedSlice()[decoded_start..];
         cb.bytes.items.len += try base64.decode(dest_buf, encoded_cert);
-        const k = try key(cb, decoded_start);
+        const k = try cb.key(decoded_start);
         const gop = try cb.map.getOrPutContext(gpa, k, .{ .cb = cb });
         if (gop.found_existing) {
             cb.bytes.items.len = decoded_start;
@@ -115,16 +114,12 @@ pub fn addCertsFromFile(
     }
 }
 
-pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key {
+pub fn key(cb: CertificateBundle, bytes_index: u32) !Key {
     const bytes = cb.bytes.items;
     const certificate = try Der.parseElement(bytes, bytes_index);
     const tbs_certificate = try Der.parseElement(bytes, certificate.start);
     const version = try Der.parseElement(bytes, tbs_certificate.start);
-    if (@bitCast(u8, version.identifier) != 0xa0 or
-        !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02"))
-    {
-        return error.UnsupportedCertificateVersion;
-    }
+    try checkVersion(bytes, version);
 
     const serial_number = try Der.parseElement(bytes, version.end);
 
@@ -144,10 +139,173 @@ pub fn key(cb: *CertificateBundle, bytes_index: u32) !Key {
     };
 }
 
+pub const Certificate = struct {
+    buffer: []const u8,
+    index: u32,
+
+    pub fn verify(subject: Certificate, issuer: Certificate) !void {
+        const subject_certificate = try Der.parseElement(subject.buffer, subject.index);
+        const subject_tbs_certificate = try Der.parseElement(subject.buffer, subject_certificate.start);
+        const subject_version = try Der.parseElement(subject.buffer, subject_tbs_certificate.start);
+        try checkVersion(subject.buffer, subject_version);
+        const subject_serial_number = try Der.parseElement(subject.buffer, subject_version.end);
+        // RFC 5280, section 4.1.2.3:
+        // "This field MUST contain the same algorithm identifier as
+        // the signatureAlgorithm field in the sequence Certificate."
+        const subject_signature = try Der.parseElement(subject.buffer, subject_serial_number.end);
+        const subject_issuer = try Der.parseElement(subject.buffer, subject_signature.end);
+        const subject_validity = try Der.parseElement(subject.buffer, subject_issuer.end);
+        //const subject_name = try Der.parseElement(subject.buffer, subject_validity.end);
+
+        const subject_sig_algo = try Der.parseElement(subject.buffer, subject_tbs_certificate.end);
+        const subject_algo_elem = try Der.parseElement(subject.buffer, subject_sig_algo.start);
+        const subject_algo = try Der.parseObjectId(subject.buffer, subject_algo_elem);
+        const subject_sig_elem = try Der.parseElement(subject.buffer, subject_sig_algo.end);
+        const subject_sig = try parseBitString(subject, subject_sig_elem);
+
+        const issuer_certificate = try Der.parseElement(issuer.buffer, issuer.index);
+        const issuer_tbs_certificate = try Der.parseElement(issuer.buffer, issuer_certificate.start);
+        const issuer_version = try Der.parseElement(issuer.buffer, issuer_tbs_certificate.start);
+        try checkVersion(issuer.buffer, issuer_version);
+        const issuer_serial_number = try Der.parseElement(issuer.buffer, issuer_version.end);
+        // RFC 5280, section 4.1.2.3:
+        // "This field MUST contain the same algorithm identifier as
+        // the signatureAlgorithm field in the sequence Certificate."
+        const issuer_signature = try Der.parseElement(issuer.buffer, issuer_serial_number.end);
+        const issuer_issuer = try Der.parseElement(issuer.buffer, issuer_signature.end);
+        const issuer_validity = try Der.parseElement(issuer.buffer, issuer_issuer.end);
+        const issuer_name = try Der.parseElement(issuer.buffer, issuer_validity.end);
+        const issuer_pub_key_info = try Der.parseElement(issuer.buffer, issuer_name.end);
+        const issuer_pub_key_signature_algorithm = try Der.parseElement(issuer.buffer, issuer_pub_key_info.start);
+        const issuer_pub_key_algo_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.start);
+        const issuer_pub_key_algo = try Der.parseObjectId(issuer.buffer, issuer_pub_key_algo_elem);
+        const issuer_pub_key_elem = try Der.parseElement(issuer.buffer, issuer_pub_key_signature_algorithm.end);
+        const issuer_pub_key = try parseBitString(issuer, issuer_pub_key_elem);
+
+        // Check that the subject's issuer name matches the issuer's subject
+        // name.
+        if (!mem.eql(u8, subject.contents(subject_issuer), issuer.contents(issuer_name))) {
+            return error.CertificateIssuerMismatch;
+        }
+
+        // TODO check the time validity for the subject
+        _ = subject_validity;
+        // TODO check the time validity for the issuer
+
+        const message = subject.buffer[subject_certificate.start..subject_tbs_certificate.end];
+        //std.debug.print("issuer algo: {any} subject algo: {any}\n", .{ issuer_pub_key_algo, subject_algo });
+        switch (subject_algo) {
+            // zig fmt: off
+              .sha1WithRSAEncryption => return verifyRsa(crypto.hash.Sha1,        message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
+            .sha224WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha224, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
+            .sha256WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha256, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
+            .sha384WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha384, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
+            .sha512WithRSAEncryption => return verifyRsa(crypto.hash.sha2.Sha512, message, subject_sig, issuer_pub_key_algo, issuer_pub_key),
+            // zig fmt: on
+            else => {
+                std.debug.print("unhandled algorithm: {any}\n", .{subject_algo});
+                return error.UnsupportedCertificateSignatureAlgorithm;
+            },
+        }
+    }
+
+    pub fn contents(cert: Certificate, elem: Der.Element) []const u8 {
+        return cert.buffer[elem.start..elem.end];
+    }
+
+    pub fn parseBitString(cert: Certificate, elem: Der.Element) ![]const u8 {
+        if (elem.identifier.tag != .bitstring) return error.CertificateFieldHasWrongDataType;
+        if (cert.buffer[elem.start] != 0) return error.CertificateHasInvalidBitString;
+        return cert.buffer[elem.start + 1 .. elem.end];
+    }
+
+    fn verifyRsa(comptime Hash: type, message: []const u8, sig: []const u8, pub_key_algo: Der.Oid, pub_key: []const u8) !void {
+        if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
+        const pub_key_seq = try Der.parseElement(pub_key, 0);
+        if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
+        const modulus_elem = try Der.parseElement(pub_key, pub_key_seq.start);
+        if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+        const exponent_elem = try Der.parseElement(pub_key, modulus_elem.end);
+        if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+        // Skip over meaningless zeroes in the modulus.
+        const modulus_raw = pub_key[modulus_elem.start..modulus_elem.end];
+        const modulus_offset = for (modulus_raw) |byte, i| {
+            if (byte != 0) break i;
+        } else modulus_raw.len;
+        const modulus = modulus_raw[modulus_offset..];
+        const exponent = pub_key[exponent_elem.start..exponent_elem.end];
+        if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
+        if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
+
+        const hash_der = switch (Hash) {
+            crypto.hash.Sha1 => [_]u8{
+                0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e,
+                0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14,
+            },
+            crypto.hash.sha2.Sha224 => [_]u8{
+                0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+                0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05,
+                0x00, 0x04, 0x1c,
+            },
+            crypto.hash.sha2.Sha256 => [_]u8{
+                0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+                0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
+                0x00, 0x04, 0x20,
+            },
+            crypto.hash.sha2.Sha384 => [_]u8{
+                0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+                0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
+                0x00, 0x04, 0x30,
+            },
+            crypto.hash.sha2.Sha512 => [_]u8{
+                0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
+                0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
+                0x00, 0x04, 0x40,
+            },
+            else => @compileError("unreachable"),
+        };
+
+        var msg_hashed: [Hash.digest_length]u8 = undefined;
+        Hash.hash(message, &msg_hashed, .{});
+
+        switch (modulus.len) {
+            inline 128, 256, 512 => |modulus_len| {
+                const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3;
+                const em: [modulus_len]u8 =
+                    [2]u8{ 0, 1 } ++
+                    ([1]u8{0xff} ** ps_len) ++
+                    [1]u8{0} ++
+                    hash_der ++
+                    msg_hashed;
+
+                const public_key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop);
+                const em_dec = try rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key, rsa.poop);
+
+                if (!mem.eql(u8, &em, &em_dec)) {
+                    try std.testing.expectEqualSlices(u8, &em, &em_dec);
+                    return error.CertificateSignatureInvalid;
+                }
+            },
+            else => {
+                return error.CertificateSignatureUnsupportedBitCount;
+            },
+        }
+    }
+};
+
+fn checkVersion(bytes: []const u8, version: Der.Element) !void {
+    if (@bitCast(u8, version.identifier) != 0xa0 or
+        !mem.eql(u8, bytes[version.start..version.end], "\x02\x01\x02"))
+    {
+        return error.UnsupportedCertificateVersion;
+    }
+}
+
 const builtin = @import("builtin");
 const std = @import("../std.zig");
 const fs = std.fs;
 const mem = std.mem;
+const crypto = std.crypto;
 const Allocator = std.mem.Allocator;
 const Der = std.crypto.Der;
 const CertificateBundle = @This();
@@ -177,3 +335,138 @@ test {
 
     try bundle.rescan(std.testing.allocator);
 }
+
+/// TODO: replace this with Frank's upcoming RSA implementation. the verify
+/// function won't have the possibility of failure - it will either identify a
+/// valid signature or an invalid signature.
+/// This code is borrowed from https://github.com/shiguredo/tls13-zig
+/// which is licensed under the Apache License Version 2.0, January 2004
+/// http://www.apache.org/licenses/
+/// The code has been modified.
+const rsa = struct {
+    const BigInt = std.math.big.int.Managed;
+
+    const PublicKey = struct {
+        n: BigInt,
+        e: BigInt,
+
+        pub fn deinit(self: *PublicKey) void {
+            self.n.deinit();
+            self.e.deinit();
+        }
+
+        pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8, allocator: std.mem.Allocator) !PublicKey {
+            var _n = try BigInt.init(allocator);
+            errdefer _n.deinit();
+            try setBytes(&_n, modulus_bytes, allocator);
+
+            var _e = try BigInt.init(allocator);
+            errdefer _e.deinit();
+            try setBytes(&_e, pub_bytes, allocator);
+
+            return .{
+                .n = _n,
+                .e = _e,
+            };
+        }
+    };
+
+    fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
+        var m = try BigInt.init(allocator);
+        defer m.deinit();
+
+        try setBytes(&m, &msg, allocator);
+
+        if (m.order(public_key.n) != .lt) {
+            return error.MessageTooLong;
+        }
+
+        var e = try BigInt.init(allocator);
+        defer e.deinit();
+
+        try pow_montgomery(&e, &m, &public_key.e, &public_key.n, allocator);
+
+        var res: [modulus_len]u8 = undefined;
+
+        try toBytes(&res, &e, allocator);
+
+        return res;
+    }
+
+    fn setBytes(r: *BigInt, bytes: []const u8, allcator: std.mem.Allocator) !void {
+        try r.set(0);
+        var tmp = try BigInt.init(allcator);
+        defer tmp.deinit();
+        for (bytes) |b| {
+            try r.shiftLeft(r, 8);
+            try tmp.set(b);
+            try r.add(r, &tmp);
+        }
+    }
+
+    fn pow_montgomery(r: *BigInt, a: *const BigInt, x: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
+        var bin_raw: [512]u8 = undefined;
+        try toBytes(&bin_raw, x, allocator);
+
+        var i: usize = 0;
+        while (bin_raw[i] == 0x00) : (i += 1) {}
+        const bin = bin_raw[i..];
+
+        try r.set(1);
+        var r1 = try BigInt.init(allocator);
+        defer r1.deinit();
+        try BigInt.copy(&r1, a.toConst());
+        i = 0;
+        while (i < bin.len * 8) : (i += 1) {
+            if (((bin[i / 8] >> @intCast(u3, (7 - (i % 8)))) & 0x1) == 0) {
+                try BigInt.mul(&r1, r, &r1);
+                try mod(&r1, &r1, n, allocator);
+                try BigInt.sqr(r, r);
+                try mod(r, r, n, allocator);
+            } else {
+                try BigInt.mul(r, r, &r1);
+                try mod(r, r, n, allocator);
+                try BigInt.sqr(&r1, &r1);
+                try mod(&r1, &r1, n, allocator);
+            }
+        }
+    }
+
+    fn toBytes(out: []u8, a: *const BigInt, allocator: std.mem.Allocator) !void {
+        const Error = error{
+            BufferTooSmall,
+        };
+
+        var mask = try BigInt.initSet(allocator, 0xFF);
+        defer mask.deinit();
+        var tmp = try BigInt.init(allocator);
+        defer tmp.deinit();
+
+        var a_copy = try BigInt.init(allocator);
+        defer a_copy.deinit();
+        try a_copy.copy(a.toConst());
+
+        // Encoding into big-endian bytes
+        var i: usize = 0;
+        while (i < out.len) : (i += 1) {
+            try tmp.bitAnd(&a_copy, &mask);
+            const b = try tmp.to(u8);
+            out[out.len - i - 1] = b;
+            try a_copy.shiftRight(&a_copy, 8);
+        }
+
+        if (!a_copy.eqZero()) {
+            return Error.BufferTooSmall;
+        }
+    }
+
+    fn mod(rem: *BigInt, a: *const BigInt, n: *const BigInt, allocator: std.mem.Allocator) !void {
+        var q = try BigInt.init(allocator);
+        defer q.deinit();
+
+        try BigInt.divFloor(&q, rem, a, n);
+    }
+
+    // TODO: flush the toilet
+    const poop = std.heap.page_allocator;
+};