Commit 22e2aaa283

Andrew Kelley <andrew@ziglang.org>
2022-12-30 01:56:46
crypto.tls: support rsa_pss_rsae_sha256 and fixes
* fix eof logic * fix read logic * fix VecPut logic * add some debug prints to remove later
1 parent e4a9b19
Changed files (2)
lib
lib/std/crypto/tls/Client.zig
@@ -536,7 +536,24 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             try sig.verify(verify_bytes, key);
                                         },
                                         .rsa_pss_rsae_sha256 => {
-                                            @panic("TODO signature scheme: rsa_pss_rsae_sha256");
+                                            if (main_cert_pub_key_algo != .rsaEncryption)
+                                                return error.TlsBadSignatureScheme;
+
+                                            const Hash = crypto.hash.sha2.Sha256;
+                                            const rsa = Certificate.rsa;
+                                            const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
+                                            const exponent = components.exponent;
+                                            const modulus = components.modulus;
+                                            switch (modulus.len) {
+                                                inline 128, 256, 512 => |modulus_len| {
+                                                    const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop);
+                                                    const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
+                                                    try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop);
+                                                },
+                                                else => {
+                                                    return error.TlsBadRsaSignatureBitCount;
+                                                },
+                                            }
                                         },
                                         else => {
                                             //std.debug.print("signature scheme: {any}\n", .{
@@ -737,7 +754,7 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
 }
 
 pub fn eof(c: Client) bool {
-    return c.received_close_notify and c.partial_ciphertext_end == 0;
+    return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end;
 }
 
 /// Returns the number of bytes read, calling the underlying read function the
@@ -822,6 +839,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
             c.partial_cleartext_idx = 0;
             c.partial_ciphertext_idx = 0;
             c.partial_ciphertext_end = 0;
+        } else {
+            std.debug.print("finished giving partial cleartext. {d} bytes ciphertext remain\n", .{
+                c.partial_ciphertext_end - c.partial_ciphertext_idx,
+            });
         }
     }
 
@@ -866,8 +887,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
 
     // There might be more bytes inside `in_stack_buffer` that need to be processed,
     // but at least frag0 will have one complete ciphertext record.
-    const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)];
-    var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len];
+    const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
+    const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
+    var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
     // We need to decipher frag0 and frag1 but there may be a ciphertext record
     // straddling the boundary. We can handle this with two memcpy() calls to
     // assemble the straddling record in between handling the two sides.
@@ -900,12 +922,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
             const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
             if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
 
-            const second_len = record_len + tls.ciphertext_record_header_len - first.len;
+            const full_record_len = record_len + tls.ciphertext_record_header_len;
+            const second_len = full_record_len - first.len;
             if (frag1.len < second_len)
                 return finishRead2(c, first, frag1, vp.total);
 
             mem.copy(u8, frag[0..in], first);
             mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+            frag = frag[0..full_record_len];
             frag1 = frag1[second_len..];
             in = 0;
             continue;
@@ -914,23 +938,35 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
         in += 1;
         const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
         in += 2;
-        _ = legacy_version;
+        //_ = legacy_version;
         const record_len = mem.readIntBig(u16, frag[in..][0..2]);
+        std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{
+            ct, legacy_version, record_len,
+        });
         if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
         in += 2;
         const end = in + record_len;
         if (end > frag.len) {
+            // We need the record header on the next iteration of the loop.
+            in -= tls.ciphertext_record_header_len;
+
             if (frag.ptr == frag1.ptr)
                 return finishRead(c, frag, in, vp.total);
 
             // A record straddles the two fragments. Copy into the now-empty first fragment.
             const first = frag[in..];
-            const second_len = record_len + tls.ciphertext_record_header_len - first.len;
-            if (frag1.len < second_len)
+            const full_record_len = record_len + tls.ciphertext_record_header_len;
+            const second_len = full_record_len - first.len;
+            if (frag1.len < second_len) {
+                std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{
+                    end, frag.len,
+                });
                 return finishRead2(c, first, frag1, vp.total);
+            }
 
             mem.copy(u8, frag[0..in], first);
             mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+            frag = frag[0..full_record_len];
             frag1 = frag1[second_len..];
             in = 0;
             continue;
@@ -991,9 +1027,11 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
                             const handshake = cleartext[ct_i..next_handshake_i];
                             switch (handshake_type) {
                                 .new_session_ticket => {
+                                    std.debug.print("new_session_ticket\n", .{});
                                     // This client implementation ignores new session tickets.
                                 },
                                 .key_update => {
+                                    std.debug.print("key_update\n", .{});
                                     switch (c.application_cipher) {
                                         inline else => |*p| {
                                             const P = @TypeOf(p.*);
@@ -1042,10 +1080,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
                                 const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
                                 mem.copy(u8, dest, msg);
                                 c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
+                                std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len});
                             } else {
                                 const amt = vp.put(msg);
+                                std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len});
                                 if (amt < msg.len) {
                                     const rest = msg[amt..];
+                                    std.debug.print("  {d} bytes to partial buffer\n", .{rest.len});
                                     c.partial_cleartext_idx = 0;
                                     c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len);
                                     mem.copy(u8, &c.partially_read_buffer, rest);
@@ -1055,6 +1096,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
                             // Output buffer was used directly which means no
                             // memory copying needs to occur, and we can move
                             // on to the next ciphertext record.
+                            std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1});
                             vp.next(cleartext.len - 1);
                         }
                     },
@@ -1166,10 +1208,6 @@ const VecPut = struct {
             const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
             mem.copy(u8, dest, src);
             bytes_i += src.len;
-            if (bytes_i >= bytes.len) {
-                vp.total += bytes_i;
-                return bytes_i;
-            }
             vp.off += src.len;
             if (vp.off >= v.iov_len) {
                 vp.off = 0;
@@ -1179,6 +1217,10 @@ const VecPut = struct {
                     return bytes_i;
                 }
             }
+            if (bytes_i >= bytes.len) {
+                vp.total += bytes_i;
+                return bytes_i;
+            }
         }
     }
 
@@ -1201,17 +1243,11 @@ const VecPut = struct {
     }
 
     fn freeSize(vp: VecPut) usize {
+        if (vp.idx >= vp.iovecs.len) return 0;
         var total: usize = 0;
-
         total += vp.iovecs[vp.idx].iov_len - vp.off;
-
-        if (vp.idx + 1 >= vp.iovecs.len)
-            return total;
-
-        for (vp.iovecs[vp.idx + 1 ..]) |v| {
-            total += v.iov_len;
-        }
-
+        if (vp.idx + 1 >= vp.iovecs.len) return total;
+        for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len;
         return total;
     }
 };
lib/std/crypto/Certificate.zig
@@ -474,19 +474,9 @@ fn verifyRsa(
     pub_key: []const u8,
 ) !void {
     if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
-    const pub_key_seq = try der.Element.parse(pub_key, 0);
-    if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
-    const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
-    if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
-    const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
-    if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
-    // Skip over meaningless zeroes in the modulus.
-    const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
-    const modulus_offset = for (modulus_raw) |byte, i| {
-        if (byte != 0) break i;
-    } else modulus_raw.len;
-    const modulus = modulus_raw[modulus_offset..];
-    const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end];
+    const pk_components = try rsa.PublicKey.parseDer(pub_key);
+    const exponent = pk_components.exponent;
+    const modulus = pk_components.modulus;
     if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
     if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
 
@@ -688,10 +678,154 @@ test {
 /// which is licensed under the Apache License Version 2.0, January 2004
 /// http://www.apache.org/licenses/
 /// The code has been modified.
-const rsa = struct {
+pub const rsa = struct {
     const BigInt = std.math.big.int.Managed;
 
-    const PublicKey = struct {
+    pub const PSSSignature = struct {
+        pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
+            var result = [1]u8{0} ** modulus_len;
+            std.mem.copy(u8, &result, msg);
+            return result;
+        }
+
+        pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void {
+            const mod_bits = try countBits(public_key.n.toConst(), allocator);
+            const em_dec = try encrypt(modulus_len, sig, public_key, allocator);
+
+            try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator);
+        }
+
+        fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void {
+            // TODO
+            // 1.   If the length of M is greater than the input limitation for
+            //      the hash function (2^61 - 1 octets for SHA-1), output
+            //      "inconsistent" and stop.
+
+            // emLen = \ceil(emBits/8)
+            const emLen = ((emBit - 1) / 8) + 1;
+            std.debug.assert(emLen == em.len);
+
+            // 2.   Let mHash = Hash(M), an octet string of length hLen.
+            var mHash: [Hash.digest_length]u8 = undefined;
+            Hash.hash(msg, &mHash, .{});
+
+            // 3.   If emLen < hLen + sLen + 2, output "inconsistent" and stop.
+            if (emLen < Hash.digest_length + sLen + 2) {
+                return error.InvalidSignature;
+            }
+
+            // 4.   If the rightmost octet of EM does not have hexadecimal value
+            //      0xbc, output "inconsistent" and stop.
+            if (em[em.len - 1] != 0xbc) {
+                return error.InvalidSignature;
+            }
+
+            // 5.   Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
+            //      and let H be the next hLen octets.
+            const maskedDB = em[0..(emLen - Hash.digest_length - 1)];
+            const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)];
+
+            // 6.   If the leftmost 8emLen - emBits bits of the leftmost octet in
+            //      maskedDB are not all equal to zero, output "inconsistent" and
+            //      stop.
+            const zero_bits = emLen * 8 - emBit;
+            var mask: u8 = maskedDB[0];
+            var i: usize = 0;
+            while (i < 8 - zero_bits) : (i += 1) {
+                mask = mask >> 1;
+            }
+            if (mask != 0) {
+                return error.InvalidSignature;
+            }
+
+            // 7.   Let dbMask = MGF(H, emLen - hLen - 1).
+            const mgf_len = emLen - Hash.digest_length - 1;
+            var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length);
+            defer allocator.free(mgf_out);
+            var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator);
+
+            // 8.   Let DB = maskedDB \xor dbMask.
+            i = 0;
+            while (i < dbMask.len) : (i += 1) {
+                dbMask[i] = maskedDB[i] ^ dbMask[i];
+            }
+
+            // 9.   Set the leftmost 8emLen - emBits bits of the leftmost octet
+            //      in DB to zero.
+            i = 0;
+            mask = 0;
+            while (i < 8 - zero_bits) : (i += 1) {
+                mask = mask << 1;
+                mask += 1;
+            }
+            dbMask[0] = dbMask[0] & mask;
+
+            // 10.  If the emLen - hLen - sLen - 2 leftmost octets of DB are not
+            //      zero or if the octet at position emLen - hLen - sLen - 1 (the
+            //      leftmost position is "position 1") does not have hexadecimal
+            //      value 0x01, output "inconsistent" and stop.
+            if (dbMask[mgf_len - sLen - 2] != 0x00) {
+                return error.InvalidSignature;
+            }
+
+            if (dbMask[mgf_len - sLen - 1] != 0x01) {
+                return error.InvalidSignature;
+            }
+
+            // 11.  Let salt be the last sLen octets of DB.
+            const salt = dbMask[(mgf_len - sLen)..];
+
+            // 12.  Let
+            //         M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
+            //      M' is an octet string of length 8 + hLen + sLen with eight
+            //      initial zero octets.
+            var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen);
+            defer allocator.free(m_p);
+            std.mem.copy(u8, m_p, &([_]u8{0} ** 8));
+            std.mem.copy(u8, m_p[8..], &mHash);
+            std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt);
+
+            // 13.  Let H' = Hash(M'), an octet string of length hLen.
+            var h_p: [Hash.digest_length]u8 = undefined;
+            Hash.hash(m_p, &h_p, .{});
+
+            // 14.  If H = H', output "consistent".  Otherwise, output
+            //      "inconsistent".
+            if (!std.mem.eql(u8, h, &h_p)) {
+                return error.InvalidSignature;
+            }
+        }
+
+        fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 {
+            var counter: usize = 0;
+            var idx: usize = 0;
+            var c: [4]u8 = undefined;
+
+            var hash = try allocator.alloc(u8, seed.len + c.len);
+            defer allocator.free(hash);
+            std.mem.copy(u8, hash, seed);
+            var hashed: [Hash.digest_length]u8 = undefined;
+
+            while (idx < len) {
+                c[0] = @intCast(u8, (counter >> 24) & 0xFF);
+                c[1] = @intCast(u8, (counter >> 16) & 0xFF);
+                c[2] = @intCast(u8, (counter >> 8) & 0xFF);
+                c[3] = @intCast(u8, counter & 0xFF);
+
+                std.mem.copy(u8, hash[seed.len..], &c);
+                Hash.hash(hash, &hashed, .{});
+
+                std.mem.copy(u8, out[idx..], &hashed);
+                idx += hashed.len;
+
+                counter += 1;
+            }
+
+            return out[0..len];
+        }
+    };
+
+    pub const PublicKey = struct {
         n: BigInt,
         e: BigInt,
 
@@ -714,6 +848,24 @@ const rsa = struct {
                 .e = _e,
             };
         }
+
+        pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } {
+            const pub_key_seq = try der.Element.parse(pub_key, 0);
+            if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
+            const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
+            if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+            const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
+            if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+            // Skip over meaningless zeroes in the modulus.
+            const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
+            const modulus_offset = for (modulus_raw) |byte, i| {
+                if (byte != 0) break i;
+            } else modulus_raw.len;
+            return .{
+                .modulus = modulus_raw[modulus_offset..],
+                .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
+            };
+        }
     };
 
     fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
@@ -812,6 +964,20 @@ const rsa = struct {
         try BigInt.divFloor(&q, rem, a, n);
     }
 
+    fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize {
+        var i: usize = 0;
+        var a_copy = try BigInt.init(allocator);
+        defer a_copy.deinit();
+        try a_copy.copy(a);
+
+        while (!a_copy.eqZero()) {
+            try a_copy.shiftRight(&a_copy, 1);
+            i += 1;
+        }
+
+        return i;
+    }
+
     // TODO: flush the toilet
-    const poop = std.heap.page_allocator;
+    pub const poop = std.heap.page_allocator;
 };