Commit 595fff7cb6
Changed files (1)
lib
std
crypto
lib/std/crypto/Tls.zig
@@ -234,7 +234,12 @@ const cipher_suites = blk: {
pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
assert(tls.state == .start);
crypto.random.bytes(&tls.x25519_priv_key);
- tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key);
+ tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| {
+ switch (err) {
+ // Only possible to happen if the private key is all zeroes.
+ error.IdentityElement => return error.InsufficientEntropy,
+ }
+ };
// random (u32)
var rand_buf: [32]u8 = undefined;
@@ -337,6 +342,14 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
};
try stream.writevAll(&iovecs);
+ const client_hello_bytes1 = hello_header[5..];
+
+ var client_handshake_key: [32]u8 = undefined;
+ var server_handshake_key: [32]u8 = undefined;
+ var client_handshake_iv: [12]u8 = undefined;
+ var server_handshake_iv: [12]u8 = undefined;
+ var cipher_suite: CipherSuite = undefined;
+
var handshake_buf: [4000]u8 = undefined;
var len: usize = 0;
var i: usize = i: {
@@ -373,7 +386,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
const legacy_session_id_echo_len = hello[34];
if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
- const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+ cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
return error.TlsIllegalParameter;
std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
const legacy_compression_method = hello[37];
@@ -404,12 +417,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
const key_size = mem.readIntBig(u16, hello[i..][0..2]);
i += 2;
if (key_size != 32) return error.TlsBadLength;
- const encrypted_key = hello[i..][0..32].*;
- const server_pub_key = try crypto.dh.X25519.scalarmult(
- tls.x25519_priv_key,
- encrypted_key,
- );
- tls.x25519_server_pub_key = server_pub_key;
+ tls.x25519_server_pub_key = hello[i..][0..32].*;
have_server_pub_key = true;
},
else => {
@@ -435,12 +443,77 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
},
else => return error.TlsIllegalParameter,
}
+
+ const shared_key = crypto.dh.X25519.scalarmult(
+ tls.x25519_priv_key,
+ tls.x25519_server_pub_key,
+ ) catch return error.TlsDecryptFailure;
+
+ switch (cipher_suite) {
+ .TLS_AES_128_GCM_SHA256 => {
+ const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+ const Hash = crypto.hash.sha2.Sha256;
+ const Hmac = crypto.auth.hmac.Hmac(Hash);
+ const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
+ const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
+ const empty_hash = emptyHash(Hash);
+ const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
+ const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
+ const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
+ const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
+ client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
+ server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
+ client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
+ server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
+ //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{
+ // std.fmt.fmtSliceHexLower(&shared_key),
+ // std.fmt.fmtSliceHexLower(&hello_hash),
+ // std.fmt.fmtSliceHexLower(&early_secret),
+ // std.fmt.fmtSliceHexLower(&empty_hash),
+ // std.fmt.fmtSliceHexLower(&derived_secret),
+ // std.fmt.fmtSliceHexLower(&handshake_secret),
+ // std.fmt.fmtSliceHexLower(&client_secret),
+ // std.fmt.fmtSliceHexLower(&server_secret),
+ //});
+ },
+ .TLS_AES_256_GCM_SHA384 => {
+ const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+ const Hash = crypto.hash.sha2.Sha384;
+ const Hmac = crypto.auth.hmac.Hmac(Hash);
+ const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+ const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
+ const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
+ const empty_hash = emptyHash(Hash);
+ const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
+ const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
+ const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
+ const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
+ client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
+ server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
+ client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
+ server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ }
},
else => return error.TlsUnexpectedMessage,
}
break :i end;
};
+ var read_seq: u64 = 0;
+
while (true) {
const end_hdr = i + 5;
if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
@@ -467,7 +540,88 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
},
.application_data => {
- std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
+ var cleartext_buf: [1000]u8 = undefined;
+ const cleartext = switch (cipher_suite) {
+ .TLS_AES_128_GCM_SHA256 => c: {
+ const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+ const ciphertext_len = record_size - AEAD.tag_length;
+ const ciphertext = handshake_buf[i..][0..ciphertext_len];
+ i += ciphertext.len;
+ if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_buf[0..ciphertext.len];
+ const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
+ const V = @Vector(AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+ read_seq += 1;
+ const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
+ //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{
+ // read_seq - 1,
+ // std.fmt.fmtSliceHexLower(&nonce),
+ // std.fmt.fmtSliceHexLower(&@as([12]u8, operand)),
+ //});
+ const ad = handshake_buf[end_hdr - 5 ..][0..5];
+ const key = server_handshake_key[0..AEAD.key_length].*;
+ AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
+ return error.TlsBadRecordMac;
+
+ break :c cleartext;
+ },
+ .TLS_AES_256_GCM_SHA384 => c: {
+ const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+ const ciphertext_len = record_size - AEAD.tag_length;
+ const ciphertext = handshake_buf[i..][0..ciphertext_len];
+ i += ciphertext.len;
+ if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+ const cleartext = cleartext_buf[0..ciphertext.len];
+ const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
+ const V = @Vector(AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
+ const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+ read_seq += 1;
+ const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
+ const ad = handshake_buf[end_hdr - 5 ..][0..5];
+ const key = server_handshake_key[0..AEAD.key_length].*;
+ AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
+ return error.TlsBadRecordMac;
+
+ break :c cleartext;
+ },
+ .TLS_CHACHA20_POLY1305_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_SHA256 => {
+ @panic("TODO");
+ },
+ .TLS_AES_128_CCM_8_SHA256 => {
+ @panic("TODO");
+ },
+ };
+
+ const inner_ct = cleartext[cleartext.len - 1];
+ switch (inner_ct) {
+ @enumToInt(ContentType.handshake) => {
+ const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
+ if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
+ switch (cleartext[0]) {
+ @enumToInt(HandshakeType.encrypted_extensions) => {
+ const ext_size = mem.readIntBig(u16, cleartext[4..6]);
+ if (ext_size != 0) {
+ @panic("TODO handle encrypted extensions");
+ }
+ std.debug.print("empty encrypted extensions\n", .{});
+ },
+ else => {
+ std.debug.print("handshake type: {d}\n", .{cleartext[0]});
+ return error.TlsUnexpectedMessage;
+ },
+ }
+ },
+ else => {
+ std.debug.print("inner content type: {d}\n", .{inner_ct});
+ return error.TlsUnexpectedMessage;
+ },
+ }
},
else => {
std.debug.print("content type: {s}\n", .{@tagName(ct)});
@@ -486,3 +640,56 @@ pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void {
_ = buffer;
@panic("hold on a minute, we didn't finish implementing the handshake yet");
}
+
+fn hkdfExpandLabel(
+ comptime Hkdf: type,
+ key: [Hkdf.prk_length]u8,
+ label: []const u8,
+ context: []const u8,
+ comptime len: usize,
+) [len]u8 {
+ const max_label_len = 255;
+ const max_context_len = 255;
+ const tls13 = "tls13 ";
+ var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
+ mem.writeIntBig(u16, buf[0..2], len);
+ buf[2] = @intCast(u8, tls13.len + label.len);
+ buf[3..][0..tls13.len].* = tls13.*;
+ var i: usize = 3 + tls13.len;
+ mem.copy(u8, buf[i..], label);
+ i += label.len;
+ buf[i] = @intCast(u8, context.len);
+ i += 1;
+ mem.copy(u8, buf[i..], context);
+ i += context.len;
+
+ var result: [len]u8 = undefined;
+ Hkdf.expand(&result, buf[0..i], key);
+ return result;
+}
+
+fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
+ var result: [Hash.digest_length]u8 = undefined;
+ Hash.hash(&.{}, &result, .{});
+ return result;
+}
+
+fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 {
+ var h = Hash.init(.{});
+ h.update(s0);
+ h.update(s1);
+ h.update(s2);
+ var result: [Hash.digest_length]u8 = undefined;
+ h.final(&result);
+ return result;
+}
+
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
+inline fn big(x: anytype) @TypeOf(x) {
+ return switch (native_endian) {
+ .Big => x,
+ .Little => @byteSwap(x),
+ };
+}