Commit 1ab008d89d

Frank Denis <124872+jedisct1@users.noreply.github.com>
2023-05-30 12:06:44
RSA: remove usage of allocators (#15901)
Individual max buffer sizes are well known, now that arithmetic doesn't require allocations any more. Also bump `main_cert_pub_key_buf`, so that e.g. `nodejs.org` public keys can fit.
1 parent 9244e4f
Changed files (2)
lib
lib/std/crypto/tls/Client.zig
@@ -424,7 +424,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
     var handshake_state: HandshakeState = .encrypted_extensions;
     var cleartext_bufs: [2][8000]u8 = undefined;
     var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined;
-    var main_cert_pub_key_buf: [300]u8 = undefined;
+    var main_cert_pub_key_buf: [600]u8 = undefined;
     var main_cert_pub_key_len: u16 = undefined;
     const now_sec = std.time.timestamp();
 
@@ -602,14 +602,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                                     const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
                                     const exponent = components.exponent;
                                     const modulus = components.modulus;
-                                    var rsa_mem_buf: [512 * 32]u8 = undefined;
-                                    var fba = std.heap.FixedBufferAllocator.init(&rsa_mem_buf);
-                                    const ally = fba.allocator();
                                     switch (modulus.len) {
                                         inline 128, 256, 512 => |modulus_len| {
                                             const key = try rsa.PublicKey.fromBytes(exponent, modulus);
                                             const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
-                                            try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, ally);
+                                            try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash);
                                         },
                                         else => {
                                             return error.TlsBadRsaSignatureBitCount;
lib/std/crypto/Certificate.zig
@@ -917,18 +917,20 @@ pub const rsa = struct {
             return result;
         }
 
-        pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void {
+        pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void {
             const mod_bits = public_key.n.bits();
             const em_dec = try encrypt(modulus_len, sig, public_key);
 
-            EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator) catch unreachable;
+            EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable;
         }
 
-        fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void {
-            // TODO
+        fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void {
             // 1.   If the length of M is greater than the input limitation for
             //      the hash function (2^61 - 1 octets for SHA-1), output
             //      "inconsistent" and stop.
+            // All the cryptographic hash functions in the standard library have a limit of >= 2^61 - 1.
+            // Even then, this check is only there for paranoia. In the context of TLS certifcates, emBit cannot exceed 4096.
+            if (emBit >= 1 << 61) return error.InvalidSignature;
 
             // emLen = \ceil(emBits/8)
             const emLen = ((emBit - 1) / 8) + 1;
@@ -952,7 +954,7 @@ pub const rsa = struct {
             // 5.   Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
             //      and let H be the next hLen octets.
             const maskedDB = em[0..(emLen - Hash.digest_length - 1)];
-            const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)];
+            const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)][0..Hash.digest_length];
 
             // 6.   If the leftmost 8emLen - emBits bits of the leftmost octet in
             //      maskedDB are not all equal to zero, output "inconsistent" and
@@ -969,9 +971,12 @@ pub const rsa = struct {
 
             // 7.   Let dbMask = MGF(H, emLen - hLen - 1).
             const mgf_len = emLen - Hash.digest_length - 1;
-            var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length);
-            defer allocator.free(mgf_out);
-            var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator);
+            var mgf_out_buf: [512]u8 = undefined;
+            if (mgf_len > mgf_out_buf.len) { // Modulus > 4096 bits
+                return error.InvalidSignature;
+            }
+            var mgf_out = mgf_out_buf[0 .. ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length];
+            var dbMask = try MGF1(Hash, mgf_out, h, mgf_len);
 
             // 8.   Let DB = maskedDB \xor dbMask.
             i = 0;
@@ -1008,8 +1013,11 @@ pub const rsa = struct {
             //         M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
             //      M' is an octet string of length 8 + hLen + sLen with eight
             //      initial zero octets.
-            var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen);
-            defer allocator.free(m_p);
+            if (sLen > Hash.digest_length) { // A seed larger than the hash length would be useless
+                return error.InvalidSignature;
+            }
+            var m_p_buf: [8 + Hash.digest_length + Hash.digest_length]u8 = undefined;
+            var m_p = m_p_buf[0 .. 8 + Hash.digest_length + sLen];
             std.mem.copyForwards(u8, m_p, &([_]u8{0} ** 8));
             std.mem.copyForwards(u8, m_p[8..], &mHash);
             std.mem.copyForwards(u8, m_p[(8 + Hash.digest_length)..], salt);
@@ -1025,14 +1033,12 @@ pub const rsa = struct {
             }
         }
 
-        fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 {
+        fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) ![]u8 {
             var counter: usize = 0;
             var idx: usize = 0;
             var c: [4]u8 = undefined;
-
-            var hash = try allocator.alloc(u8, seed.len + c.len);
-            defer allocator.free(hash);
-            std.mem.copyForwards(u8, hash, seed);
+            var hash: [Hash.digest_length + c.len]u8 = undefined;
+            @memcpy(hash[0..Hash.digest_length], seed);
             var hashed: [Hash.digest_length]u8 = undefined;
 
             while (idx < len) {
@@ -1042,7 +1048,7 @@ pub const rsa = struct {
                 c[3] = @intCast(u8, counter & 0xFF);
 
                 std.mem.copyForwards(u8, hash[seed.len..], &c);
-                Hash.hash(hash, &hashed, .{});
+                Hash.hash(&hash, &hashed, .{});
 
                 std.mem.copyForwards(u8, out[idx..], &hashed);
                 idx += hashed.len;