Commit 22e2aaa283
Changed files (2)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -536,7 +536,24 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
try sig.verify(verify_bytes, key);
},
.rsa_pss_rsae_sha256 => {
- @panic("TODO signature scheme: rsa_pss_rsae_sha256");
+ if (main_cert_pub_key_algo != .rsaEncryption)
+ return error.TlsBadSignatureScheme;
+
+ const Hash = crypto.hash.sha2.Sha256;
+ const rsa = Certificate.rsa;
+ const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
+ const exponent = components.exponent;
+ const modulus = components.modulus;
+ switch (modulus.len) {
+ inline 128, 256, 512 => |modulus_len| {
+ const key = try rsa.PublicKey.fromBytes(exponent, modulus, rsa.poop);
+ const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
+ try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash, rsa.poop);
+ },
+ else => {
+ return error.TlsBadRsaSignatureBitCount;
+ },
+ }
},
else => {
//std.debug.print("signature scheme: {any}\n", .{
@@ -737,7 +754,7 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
}
pub fn eof(c: Client) bool {
- return c.received_close_notify and c.partial_ciphertext_end == 0;
+ return c.received_close_notify and c.partial_ciphertext_idx >= c.partial_ciphertext_end;
}
/// Returns the number of bytes read, calling the underlying read function the
@@ -822,6 +839,10 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = 0;
c.partial_ciphertext_end = 0;
+ } else {
+ std.debug.print("finished giving partial cleartext. {d} bytes ciphertext remain\n", .{
+ c.partial_ciphertext_end - c.partial_ciphertext_idx,
+ });
}
}
@@ -866,8 +887,9 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// There might be more bytes inside `in_stack_buffer` that need to be processed,
// but at least frag0 will have one complete ciphertext record.
- const frag0 = c.partially_read_buffer[0..@min(c.partially_read_buffer.len, actual_read_len)];
- var frag1 = in_stack_buffer[0 .. actual_read_len - frag0.len];
+ const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
+ const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
+ var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
// We need to decipher frag0 and frag1 but there may be a ciphertext record
// straddling the boundary. We can handle this with two memcpy() calls to
// assemble the straddling record in between handling the two sides.
@@ -900,12 +922,14 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
- const second_len = record_len + tls.ciphertext_record_header_len - first.len;
+ const full_record_len = record_len + tls.ciphertext_record_header_len;
+ const second_len = full_record_len - first.len;
if (frag1.len < second_len)
return finishRead2(c, first, frag1, vp.total);
mem.copy(u8, frag[0..in], first);
mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+ frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
in = 0;
continue;
@@ -914,23 +938,35 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
in += 1;
const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
in += 2;
- _ = legacy_version;
+ //_ = legacy_version;
const record_len = mem.readIntBig(u16, frag[in..][0..2]);
+ std.debug.print("ct={any} legacy_version={x} record_len={d}\n", .{
+ ct, legacy_version, record_len,
+ });
if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
in += 2;
const end = in + record_len;
if (end > frag.len) {
+ // We need the record header on the next iteration of the loop.
+ in -= tls.ciphertext_record_header_len;
+
if (frag.ptr == frag1.ptr)
return finishRead(c, frag, in, vp.total);
// A record straddles the two fragments. Copy into the now-empty first fragment.
const first = frag[in..];
- const second_len = record_len + tls.ciphertext_record_header_len - first.len;
- if (frag1.len < second_len)
+ const full_record_len = record_len + tls.ciphertext_record_header_len;
+ const second_len = full_record_len - first.len;
+ if (frag1.len < second_len) {
+ std.debug.print("end > frag.len finishRead2 end={d} frag.len={d}\n", .{
+ end, frag.len,
+ });
return finishRead2(c, first, frag1, vp.total);
+ }
mem.copy(u8, frag[0..in], first);
mem.copy(u8, frag[first.len..], frag1[0..second_len]);
+ frag = frag[0..full_record_len];
frag1 = frag1[second_len..];
in = 0;
continue;
@@ -991,9 +1027,11 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
.new_session_ticket => {
+ std.debug.print("new_session_ticket\n", .{});
// This client implementation ignores new session tickets.
},
.key_update => {
+ std.debug.print("key_update\n", .{});
switch (c.application_cipher) {
inline else => |*p| {
const P = @TypeOf(p.*);
@@ -1042,10 +1080,13 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
const dest = c.partially_read_buffer[c.partial_ciphertext_idx..];
mem.copy(u8, dest, msg);
c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), c.partial_ciphertext_idx + msg.len);
+ std.debug.print("application_data {d} bytes to partial buffer\n", .{msg.len});
} else {
const amt = vp.put(msg);
+ std.debug.print("application_data {d} bytes to read buffer\n", .{msg.len});
if (amt < msg.len) {
const rest = msg[amt..];
+ std.debug.print(" {d} bytes to partial buffer\n", .{rest.len});
c.partial_cleartext_idx = 0;
c.partial_ciphertext_idx = @intCast(@TypeOf(c.partial_ciphertext_idx), rest.len);
mem.copy(u8, &c.partially_read_buffer, rest);
@@ -1055,6 +1096,7 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
// Output buffer was used directly which means no
// memory copying needs to occur, and we can move
// on to the next ciphertext record.
+ std.debug.print("application_data {d} bytes directly to read buffer\n", .{cleartext.len - 1});
vp.next(cleartext.len - 1);
}
},
@@ -1166,10 +1208,6 @@ const VecPut = struct {
const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
mem.copy(u8, dest, src);
bytes_i += src.len;
- if (bytes_i >= bytes.len) {
- vp.total += bytes_i;
- return bytes_i;
- }
vp.off += src.len;
if (vp.off >= v.iov_len) {
vp.off = 0;
@@ -1179,6 +1217,10 @@ const VecPut = struct {
return bytes_i;
}
}
+ if (bytes_i >= bytes.len) {
+ vp.total += bytes_i;
+ return bytes_i;
+ }
}
}
@@ -1201,17 +1243,11 @@ const VecPut = struct {
}
fn freeSize(vp: VecPut) usize {
+ if (vp.idx >= vp.iovecs.len) return 0;
var total: usize = 0;
-
total += vp.iovecs[vp.idx].iov_len - vp.off;
-
- if (vp.idx + 1 >= vp.iovecs.len)
- return total;
-
- for (vp.iovecs[vp.idx + 1 ..]) |v| {
- total += v.iov_len;
- }
-
+ if (vp.idx + 1 >= vp.iovecs.len) return total;
+ for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len;
return total;
}
};
lib/std/crypto/Certificate.zig
@@ -474,19 +474,9 @@ fn verifyRsa(
pub_key: []const u8,
) !void {
if (pub_key_algo != .rsaEncryption) return error.CertificateSignatureAlgorithmMismatch;
- const pub_key_seq = try der.Element.parse(pub_key, 0);
- if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
- const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
- if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
- const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
- if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
- // Skip over meaningless zeroes in the modulus.
- const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
- const modulus_offset = for (modulus_raw) |byte, i| {
- if (byte != 0) break i;
- } else modulus_raw.len;
- const modulus = modulus_raw[modulus_offset..];
- const exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end];
+ const pk_components = try rsa.PublicKey.parseDer(pub_key);
+ const exponent = pk_components.exponent;
+ const modulus = pk_components.modulus;
if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid;
if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength;
@@ -688,10 +678,154 @@ test {
/// which is licensed under the Apache License Version 2.0, January 2004
/// http://www.apache.org/licenses/
/// The code has been modified.
-const rsa = struct {
+pub const rsa = struct {
const BigInt = std.math.big.int.Managed;
- const PublicKey = struct {
+ pub const PSSSignature = struct {
+ pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 {
+ var result = [1]u8{0} ** modulus_len;
+ std.mem.copy(u8, &result, msg);
+ return result;
+ }
+
+ pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type, allocator: std.mem.Allocator) !void {
+ const mod_bits = try countBits(public_key.n.toConst(), allocator);
+ const em_dec = try encrypt(modulus_len, sig, public_key, allocator);
+
+ try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash, allocator);
+ }
+
+ fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type, allocator: std.mem.Allocator) !void {
+ // TODO
+ // 1. If the length of M is greater than the input limitation for
+ // the hash function (2^61 - 1 octets for SHA-1), output
+ // "inconsistent" and stop.
+
+ // emLen = \ceil(emBits/8)
+ const emLen = ((emBit - 1) / 8) + 1;
+ std.debug.assert(emLen == em.len);
+
+ // 2. Let mHash = Hash(M), an octet string of length hLen.
+ var mHash: [Hash.digest_length]u8 = undefined;
+ Hash.hash(msg, &mHash, .{});
+
+ // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
+ if (emLen < Hash.digest_length + sLen + 2) {
+ return error.InvalidSignature;
+ }
+
+ // 4. If the rightmost octet of EM does not have hexadecimal value
+ // 0xbc, output "inconsistent" and stop.
+ if (em[em.len - 1] != 0xbc) {
+ return error.InvalidSignature;
+ }
+
+ // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
+ // and let H be the next hLen octets.
+ const maskedDB = em[0..(emLen - Hash.digest_length - 1)];
+ const h = em[(emLen - Hash.digest_length - 1)..(emLen - 1)];
+
+ // 6. If the leftmost 8emLen - emBits bits of the leftmost octet in
+ // maskedDB are not all equal to zero, output "inconsistent" and
+ // stop.
+ const zero_bits = emLen * 8 - emBit;
+ var mask: u8 = maskedDB[0];
+ var i: usize = 0;
+ while (i < 8 - zero_bits) : (i += 1) {
+ mask = mask >> 1;
+ }
+ if (mask != 0) {
+ return error.InvalidSignature;
+ }
+
+ // 7. Let dbMask = MGF(H, emLen - hLen - 1).
+ const mgf_len = emLen - Hash.digest_length - 1;
+ var mgf_out = try allocator.alloc(u8, ((mgf_len - 1) / Hash.digest_length + 1) * Hash.digest_length);
+ defer allocator.free(mgf_out);
+ var dbMask = try MGF1(mgf_out, h, mgf_len, Hash, allocator);
+
+ // 8. Let DB = maskedDB \xor dbMask.
+ i = 0;
+ while (i < dbMask.len) : (i += 1) {
+ dbMask[i] = maskedDB[i] ^ dbMask[i];
+ }
+
+ // 9. Set the leftmost 8emLen - emBits bits of the leftmost octet
+ // in DB to zero.
+ i = 0;
+ mask = 0;
+ while (i < 8 - zero_bits) : (i += 1) {
+ mask = mask << 1;
+ mask += 1;
+ }
+ dbMask[0] = dbMask[0] & mask;
+
+ // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not
+ // zero or if the octet at position emLen - hLen - sLen - 1 (the
+ // leftmost position is "position 1") does not have hexadecimal
+ // value 0x01, output "inconsistent" and stop.
+ if (dbMask[mgf_len - sLen - 2] != 0x00) {
+ return error.InvalidSignature;
+ }
+
+ if (dbMask[mgf_len - sLen - 1] != 0x01) {
+ return error.InvalidSignature;
+ }
+
+ // 11. Let salt be the last sLen octets of DB.
+ const salt = dbMask[(mgf_len - sLen)..];
+
+ // 12. Let
+ // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
+ // M' is an octet string of length 8 + hLen + sLen with eight
+ // initial zero octets.
+ var m_p = try allocator.alloc(u8, 8 + Hash.digest_length + sLen);
+ defer allocator.free(m_p);
+ std.mem.copy(u8, m_p, &([_]u8{0} ** 8));
+ std.mem.copy(u8, m_p[8..], &mHash);
+ std.mem.copy(u8, m_p[(8 + Hash.digest_length)..], salt);
+
+ // 13. Let H' = Hash(M'), an octet string of length hLen.
+ var h_p: [Hash.digest_length]u8 = undefined;
+ Hash.hash(m_p, &h_p, .{});
+
+ // 14. If H = H', output "consistent". Otherwise, output
+ // "inconsistent".
+ if (!std.mem.eql(u8, h, &h_p)) {
+ return error.InvalidSignature;
+ }
+ }
+
+ fn MGF1(out: []u8, seed: []const u8, len: usize, comptime Hash: type, allocator: std.mem.Allocator) ![]u8 {
+ var counter: usize = 0;
+ var idx: usize = 0;
+ var c: [4]u8 = undefined;
+
+ var hash = try allocator.alloc(u8, seed.len + c.len);
+ defer allocator.free(hash);
+ std.mem.copy(u8, hash, seed);
+ var hashed: [Hash.digest_length]u8 = undefined;
+
+ while (idx < len) {
+ c[0] = @intCast(u8, (counter >> 24) & 0xFF);
+ c[1] = @intCast(u8, (counter >> 16) & 0xFF);
+ c[2] = @intCast(u8, (counter >> 8) & 0xFF);
+ c[3] = @intCast(u8, counter & 0xFF);
+
+ std.mem.copy(u8, hash[seed.len..], &c);
+ Hash.hash(hash, &hashed, .{});
+
+ std.mem.copy(u8, out[idx..], &hashed);
+ idx += hashed.len;
+
+ counter += 1;
+ }
+
+ return out[0..len];
+ }
+ };
+
+ pub const PublicKey = struct {
n: BigInt,
e: BigInt,
@@ -714,6 +848,24 @@ const rsa = struct {
.e = _e,
};
}
+
+ pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } {
+ const pub_key_seq = try der.Element.parse(pub_key, 0);
+ if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType;
+ const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start);
+ if (modulus_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+ const exponent_elem = try der.Element.parse(pub_key, modulus_elem.slice.end);
+ if (exponent_elem.identifier.tag != .integer) return error.CertificateFieldHasWrongDataType;
+ // Skip over meaningless zeroes in the modulus.
+ const modulus_raw = pub_key[modulus_elem.slice.start..modulus_elem.slice.end];
+ const modulus_offset = for (modulus_raw) |byte, i| {
+ if (byte != 0) break i;
+ } else modulus_raw.len;
+ return .{
+ .modulus = modulus_raw[modulus_offset..],
+ .exponent = pub_key[exponent_elem.slice.start..exponent_elem.slice.end],
+ };
+ }
};
fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey, allocator: std.mem.Allocator) ![modulus_len]u8 {
@@ -812,6 +964,20 @@ const rsa = struct {
try BigInt.divFloor(&q, rem, a, n);
}
+ fn countBits(a: std.math.big.int.Const, allocator: std.mem.Allocator) !usize {
+ var i: usize = 0;
+ var a_copy = try BigInt.init(allocator);
+ defer a_copy.deinit();
+ try a_copy.copy(a);
+
+ while (!a_copy.eqZero()) {
+ try a_copy.shiftRight(&a_copy, 1);
+ i += 1;
+ }
+
+ return i;
+ }
+
// TODO: flush the toilet
- const poop = std.heap.page_allocator;
+ pub const poop = std.heap.page_allocator;
};