Commit 595fff7cb6

Andrew Kelley <andrew@ziglang.org>
2022-12-15 08:55:33
std.crypto.Tls: decrypting handshake messages
1 parent 920e5bc
Changed files (1)
lib
std
crypto
lib/std/crypto/Tls.zig
@@ -234,7 +234,12 @@ const cipher_suites = blk: {
 pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
     assert(tls.state == .start);
     crypto.random.bytes(&tls.x25519_priv_key);
-    tls.x25519_pub_key = try crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key);
+    tls.x25519_pub_key = crypto.dh.X25519.recoverPublicKey(tls.x25519_priv_key) catch |err| {
+        switch (err) {
+            // Only possible to happen if the private key is all zeroes.
+            error.IdentityElement => return error.InsufficientEntropy,
+        }
+    };
 
     // random (u32)
     var rand_buf: [32]u8 = undefined;
@@ -337,6 +342,14 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
     };
     try stream.writevAll(&iovecs);
 
+    const client_hello_bytes1 = hello_header[5..];
+
+    var client_handshake_key: [32]u8 = undefined;
+    var server_handshake_key: [32]u8 = undefined;
+    var client_handshake_iv: [12]u8 = undefined;
+    var server_handshake_iv: [12]u8 = undefined;
+    var cipher_suite: CipherSuite = undefined;
+
     var handshake_buf: [4000]u8 = undefined;
     var len: usize = 0;
     var i: usize = i: {
@@ -373,7 +386,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
                 const legacy_session_id_echo_len = hello[34];
                 if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
                 const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
-                const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+                cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
                     return error.TlsIllegalParameter;
                 std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
                 const legacy_compression_method = hello[37];
@@ -404,12 +417,7 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
                                     const key_size = mem.readIntBig(u16, hello[i..][0..2]);
                                     i += 2;
                                     if (key_size != 32) return error.TlsBadLength;
-                                    const encrypted_key = hello[i..][0..32].*;
-                                    const server_pub_key = try crypto.dh.X25519.scalarmult(
-                                        tls.x25519_priv_key,
-                                        encrypted_key,
-                                    );
-                                    tls.x25519_server_pub_key = server_pub_key;
+                                    tls.x25519_server_pub_key = hello[i..][0..32].*;
                                     have_server_pub_key = true;
                                 },
                                 else => {
@@ -435,12 +443,77 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
                     },
                     else => return error.TlsIllegalParameter,
                 }
+
+                const shared_key = crypto.dh.X25519.scalarmult(
+                    tls.x25519_priv_key,
+                    tls.x25519_server_pub_key,
+                ) catch return error.TlsDecryptFailure;
+
+                switch (cipher_suite) {
+                    .TLS_AES_128_GCM_SHA256 => {
+                        const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+                        const Hash = crypto.hash.sha2.Sha256;
+                        const Hmac = crypto.auth.hmac.Hmac(Hash);
+                        const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+                        const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
+                        const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
+                        const empty_hash = emptyHash(Hash);
+                        const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
+                        const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
+                        const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
+                        const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
+                        client_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
+                        server_handshake_key[0..AEAD.key_length].* = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
+                        client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
+                        server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
+                        //std.debug.print("shared_key: {}\nhello_hash: {}\nearly_secret: {}\nempty_hash: {}\nderived_secret: {}\nhandshake_secret: {}\n client_secret: {}\n server_secret: {}\n", .{
+                        //    std.fmt.fmtSliceHexLower(&shared_key),
+                        //    std.fmt.fmtSliceHexLower(&hello_hash),
+                        //    std.fmt.fmtSliceHexLower(&early_secret),
+                        //    std.fmt.fmtSliceHexLower(&empty_hash),
+                        //    std.fmt.fmtSliceHexLower(&derived_secret),
+                        //    std.fmt.fmtSliceHexLower(&handshake_secret),
+                        //    std.fmt.fmtSliceHexLower(&client_secret),
+                        //    std.fmt.fmtSliceHexLower(&server_secret),
+                        //});
+                    },
+                    .TLS_AES_256_GCM_SHA384 => {
+                        const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+                        const Hash = crypto.hash.sha2.Sha384;
+                        const Hmac = crypto.auth.hmac.Hmac(Hash);
+                        const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+
+                        const hello_hash = helloHash(client_hello_bytes1, host, frag, Hash);
+                        const early_secret = Hkdf.extract(&[1]u8{0}, &([1]u8{0} ** Hash.digest_length));
+                        const empty_hash = emptyHash(Hash);
+                        const derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length);
+                        const handshake_secret = Hkdf.extract(&derived_secret, &shared_key);
+                        const client_secret = hkdfExpandLabel(Hkdf, handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length);
+                        const server_secret = hkdfExpandLabel(Hkdf, handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length);
+                        client_handshake_key = hkdfExpandLabel(Hkdf, client_secret, "key", "", AEAD.key_length);
+                        server_handshake_key = hkdfExpandLabel(Hkdf, server_secret, "key", "", AEAD.key_length);
+                        client_handshake_iv = hkdfExpandLabel(Hkdf, client_secret, "iv", "", AEAD.nonce_length);
+                        server_handshake_iv = hkdfExpandLabel(Hkdf, server_secret, "iv", "", AEAD.nonce_length);
+                    },
+                    .TLS_CHACHA20_POLY1305_SHA256 => {
+                        @panic("TODO");
+                    },
+                    .TLS_AES_128_CCM_SHA256 => {
+                        @panic("TODO");
+                    },
+                    .TLS_AES_128_CCM_8_SHA256 => {
+                        @panic("TODO");
+                    },
+                }
             },
             else => return error.TlsUnexpectedMessage,
         }
         break :i end;
     };
 
+    var read_seq: u64 = 0;
+
     while (true) {
         const end_hdr = i + 5;
         if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
@@ -467,7 +540,88 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
                 if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
             },
             .application_data => {
-                std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
+                var cleartext_buf: [1000]u8 = undefined;
+                const cleartext = switch (cipher_suite) {
+                    .TLS_AES_128_GCM_SHA256 => c: {
+                        const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
+                        const ciphertext_len = record_size - AEAD.tag_length;
+                        const ciphertext = handshake_buf[i..][0..ciphertext_len];
+                        i += ciphertext.len;
+                        if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+                        const cleartext = cleartext_buf[0..ciphertext.len];
+                        const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
+                        const V = @Vector(AEAD.nonce_length, u8);
+                        const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
+                        const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+                        read_seq += 1;
+                        const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
+                        //std.debug.print("seq: {d} nonce: {} operand: {}\n", .{
+                        //    read_seq - 1,
+                        //    std.fmt.fmtSliceHexLower(&nonce),
+                        //    std.fmt.fmtSliceHexLower(&@as([12]u8, operand)),
+                        //});
+                        const ad = handshake_buf[end_hdr - 5 ..][0..5];
+                        const key = server_handshake_key[0..AEAD.key_length].*;
+                        AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
+                            return error.TlsBadRecordMac;
+
+                        break :c cleartext;
+                    },
+                    .TLS_AES_256_GCM_SHA384 => c: {
+                        const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
+                        const ciphertext_len = record_size - AEAD.tag_length;
+                        const ciphertext = handshake_buf[i..][0..ciphertext_len];
+                        i += ciphertext.len;
+                        if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
+                        const cleartext = cleartext_buf[0..ciphertext.len];
+                        const auth_tag = handshake_buf[i..][0..AEAD.tag_length].*;
+                        const V = @Vector(AEAD.nonce_length, u8);
+                        const pad = [1]u8{0} ** (AEAD.nonce_length - 8);
+                        const operand: V = pad ++ @bitCast([8]u8, big(read_seq));
+                        read_seq += 1;
+                        const nonce: [AEAD.nonce_length]u8 = @as(V, server_handshake_iv) ^ operand;
+                        const ad = handshake_buf[end_hdr - 5 ..][0..5];
+                        const key = server_handshake_key[0..AEAD.key_length].*;
+                        AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, key) catch
+                            return error.TlsBadRecordMac;
+
+                        break :c cleartext;
+                    },
+                    .TLS_CHACHA20_POLY1305_SHA256 => {
+                        @panic("TODO");
+                    },
+                    .TLS_AES_128_CCM_SHA256 => {
+                        @panic("TODO");
+                    },
+                    .TLS_AES_128_CCM_8_SHA256 => {
+                        @panic("TODO");
+                    },
+                };
+
+                const inner_ct = cleartext[cleartext.len - 1];
+                switch (inner_ct) {
+                    @enumToInt(ContentType.handshake) => {
+                        const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
+                        if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
+                        switch (cleartext[0]) {
+                            @enumToInt(HandshakeType.encrypted_extensions) => {
+                                const ext_size = mem.readIntBig(u16, cleartext[4..6]);
+                                if (ext_size != 0) {
+                                    @panic("TODO handle encrypted extensions");
+                                }
+                                std.debug.print("empty encrypted extensions\n", .{});
+                            },
+                            else => {
+                                std.debug.print("handshake type: {d}\n", .{cleartext[0]});
+                                return error.TlsUnexpectedMessage;
+                            },
+                        }
+                    },
+                    else => {
+                        std.debug.print("inner content type: {d}\n", .{inner_ct});
+                        return error.TlsUnexpectedMessage;
+                    },
+                }
             },
             else => {
                 std.debug.print("content type: {s}\n", .{@tagName(ct)});
@@ -486,3 +640,56 @@ pub fn writeAll(tls: *Tls, stream: net.Stream, buffer: []const u8) !void {
     _ = buffer;
     @panic("hold on a minute, we didn't finish implementing the handshake yet");
 }
+
+fn hkdfExpandLabel(
+    comptime Hkdf: type,
+    key: [Hkdf.prk_length]u8,
+    label: []const u8,
+    context: []const u8,
+    comptime len: usize,
+) [len]u8 {
+    const max_label_len = 255;
+    const max_context_len = 255;
+    const tls13 = "tls13 ";
+    var buf: [2 + 1 + tls13.len + max_label_len + 1 + max_context_len]u8 = undefined;
+    mem.writeIntBig(u16, buf[0..2], len);
+    buf[2] = @intCast(u8, tls13.len + label.len);
+    buf[3..][0..tls13.len].* = tls13.*;
+    var i: usize = 3 + tls13.len;
+    mem.copy(u8, buf[i..], label);
+    i += label.len;
+    buf[i] = @intCast(u8, context.len);
+    i += 1;
+    mem.copy(u8, buf[i..], context);
+    i += context.len;
+
+    var result: [len]u8 = undefined;
+    Hkdf.expand(&result, buf[0..i], key);
+    return result;
+}
+
+fn emptyHash(comptime Hash: type) [Hash.digest_length]u8 {
+    var result: [Hash.digest_length]u8 = undefined;
+    Hash.hash(&.{}, &result, .{});
+    return result;
+}
+
+fn helloHash(s0: []const u8, s1: []const u8, s2: []const u8, comptime Hash: type) [Hash.digest_length]u8 {
+    var h = Hash.init(.{});
+    h.update(s0);
+    h.update(s1);
+    h.update(s2);
+    var result: [Hash.digest_length]u8 = undefined;
+    h.final(&result);
+    return result;
+}
+
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
+inline fn big(x: anytype) @TypeOf(x) {
+    return switch (native_endian) {
+        .Big => x,
+        .Little => @byteSwap(x),
+    };
+}