Commit ceb211e65f

Andrew Kelley <andrew@ziglang.org>
2022-12-28 06:59:19
std.crypto.tls.Client: handle key_update message
1 parent 5bbedb6
Changed files (2)
lib
std
lib/std/crypto/tls/Client.zig
@@ -24,7 +24,7 @@ read_seq: u64,
 write_seq: u64,
 /// The size is enough to contain exactly one TLSCiphertext record.
 partially_read_buffer: [tls.max_ciphertext_record_len]u8,
-/// The number of partially read bytes inside `partiall_read_buffer`.
+/// The number of partially read bytes inside `partially_read_buffer`.
 partially_read_len: u15,
 eof: bool,
 
@@ -584,6 +584,8 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
                                             //    std.fmt.fmtSliceHexLower(&server_secret),
                                             //});
                                             break :c @unionInit(ApplicationCipher, @tagName(tag), .{
+                                                .client_secret = client_secret,
+                                                .server_secret = server_secret,
                                                 .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
                                                 .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
                                                 .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
@@ -669,7 +671,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
                 ciphertext_end += auth_tag.len;
                 const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                 const operand: V = pad ++ @bitCast([8]u8, big(c.write_seq));
-                c.write_seq += 1;
+                c.write_seq += 1; // TODO send key_update on overflow
                 const nonce = @as(V, p.client_iv) ^ operand;
                 P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
                 //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}\n", .{
@@ -789,7 +791,8 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                     },
                 };
 
-                const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
+                const cleartext = buffer[out..][0..cleartext_len];
+                const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
                 switch (inner_ct) {
                     .alert => {
                         const level = @intToEnum(tls.AlertLevel, buffer[out]);
@@ -802,7 +805,56 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                         return error.TlsAlert;
                     },
                     .handshake => {
-                        std.debug.print("the server wants to keep shaking hands\n", .{});
+                        var ct_i: usize = 0;
+                        while (true) {
+                            const handshake_type = cleartext[ct_i];
+                            ct_i += 1;
+                            const handshake_len = mem.readIntBig(u24, cleartext[ct_i..][0..3]);
+                            ct_i += 3;
+                            const next_handshake_i = ct_i + handshake_len;
+                            if (next_handshake_i > cleartext.len - 1)
+                                return error.TlsBadLength;
+                            const handshake = cleartext[ct_i..next_handshake_i];
+                            switch (handshake_type) {
+                                @enumToInt(HandshakeType.new_session_ticket) => {
+                                    std.debug.print("server sent a new session ticket\n", .{});
+                                },
+                                @enumToInt(HandshakeType.key_update) => {
+                                    switch (c.application_cipher) {
+                                        inline else => |*p| {
+                                            const P = @TypeOf(p.*);
+                                            const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length);
+                                            p.server_secret = server_secret;
+                                            p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+                                            p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+                                        },
+                                    }
+                                    c.read_seq = 0;
+
+                                    switch (@intToEnum(tls.KeyUpdateRequest, handshake[0])) {
+                                        .update_requested => {
+                                            switch (c.application_cipher) {
+                                                inline else => |*p| {
+                                                    const P = @TypeOf(p.*);
+                                                    const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length);
+                                                    p.client_secret = client_secret;
+                                                    p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+                                                    p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+                                                },
+                                            }
+                                            c.write_seq = 0;
+                                        },
+                                        .update_not_requested => {},
+                                        _ => return error.TlsIllegalParameter,
+                                    }
+                                },
+                                else => {
+                                    return error.TlsUnexpectedMessage;
+                                },
+                            }
+                            ct_i = next_handshake_i;
+                            if (ct_i >= cleartext.len - 1) break;
+                        }
                     },
                     .application_data => {
                         out += cleartext_len - 1;
lib/std/crypto/tls.zig
@@ -227,6 +227,12 @@ pub const CertificateType = enum(u8) {
     _,
 };
 
+pub const KeyUpdateRequest = enum(u8) {
+    update_not_requested = 0,
+    update_requested = 1,
+    _,
+};
+
 pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type {
     return struct {
         pub const AEAD = AeadType;
@@ -261,6 +267,8 @@ pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type
         pub const Hmac = crypto.auth.hmac.Hmac(Hash);
         pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
 
+        client_secret: [Hash.digest_length]u8,
+        server_secret: [Hash.digest_length]u8,
         client_key: [AEAD.key_length]u8,
         server_key: [AEAD.key_length]u8,
         client_iv: [AEAD.nonce_length]u8,