Commit de53e6e4f2

Jacob Young <jacobly0@users.noreply.github.com>
2024-11-05 07:37:12
std.crypto.tls: improve debuggability of encrypted connections
By default, programs built in debug mode that open a https connection will append secrets to the file specified in the SSLKEYLOGFILE environment variable to allow protocol debugging by external programs.
1 parent d86a8ae
Changed files (3)
lib
std
lib/std/crypto/tls/Client.zig
@@ -33,7 +33,7 @@ received_close_notify: bool,
 /// This makes the application vulnerable to truncation attacks unless the
 /// application layer itself verifies that the amount of data received equals
 /// the amount of data expected, such as HTTP with the Content-Length header.
-allow_truncation_attacks: bool = false,
+allow_truncation_attacks: bool,
 application_cipher: tls.ApplicationCipher,
 /// The size is enough to contain exactly one TLSCiphertext record.
 /// This buffer is segmented into four parts:
@@ -44,6 +44,24 @@ application_cipher: tls.ApplicationCipher,
 /// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
 /// `partial_ciphertext_end` describe the span of the segments.
 partially_read_buffer: [tls.max_ciphertext_record_len]u8,
+/// If non-null, ssl secrets are logged to a file.  Creating such a log file allows other
+/// programs with access to that file to decrypt all traffic over this connection.
+ssl_key_log: ?struct {
+    client_key_seq: u64,
+    server_key_seq: u64,
+    client_random: [32]u8,
+    file: std.fs.File,
+
+    fn clientCounter(key_log: *@This()) u64 {
+        defer key_log.client_key_seq += 1;
+        return key_log.client_key_seq;
+    }
+
+    fn serverCounter(key_log: *@This()) u64 {
+        defer key_log.server_key_seq += 1;
+        return key_log.server_key_seq;
+    }
+},
 
 /// This is an example of the type that is needed by the read and write
 /// functions. It can have any fields but it must at least have these
@@ -88,6 +106,32 @@ pub const StreamInterface = struct {
     }
 };
 
+pub const Options = struct {
+    /// How to perform host verification of server certificates.
+    host: union(enum) {
+        /// No host verification is performed, which prevents a trusted connection from
+        /// being established.
+        no_verification,
+        /// Verify that the server certificate was issues for a given host.
+        explicit: []const u8,
+    },
+    /// How to verify the authenticity of server certificates.
+    ca: union(enum) {
+        /// No ca verification is performed, which prevents a trusted connection from
+        /// being established.
+        no_verification,
+        /// Verify that the server certificate is a valid self-signed certificate.
+        /// This provides no authorization guarantees, as anyone can create a
+        /// self-signed certificate.
+        self_signed,
+        /// Verify that the server certificate is authorized by a given ca bundle.
+        bundle: Certificate.Bundle,
+    },
+    /// If non-null, ssl secrets are logged to this file.  Creating such a log file allows
+    /// other programs with access to that file to decrypt all traffic over this connection.
+    ssl_key_log_file: ?std.fs.File = null,
+};
+
 pub fn InitError(comptime Stream: type) type {
     return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
         InsufficientEntropy,
@@ -140,12 +184,17 @@ pub fn InitError(comptime Stream: type) type {
 /// must conform to `StreamInterface`.
 ///
 /// `host` is only borrowed during this function call.
-pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client {
+pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client {
+    const host = switch (options.host) {
+        .no_verification => "",
+        .explicit => |host| host,
+    };
     const host_len: u16 = @intCast(host.len);
 
     var random_buffer: [128]u8 = undefined;
     crypto.random.bytes(&random_buffer);
     const client_hello_rand = random_buffer[0..32].*;
+    var key_seq: u64 = 0;
     var server_hello_rand: [32]u8 = undefined;
     const legacy_session_id = random_buffer[32..64].*;
 
@@ -179,15 +228,21 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
             array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++
             int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++
             array(u16, u8, key_share.x25519_kp.public_key),
-    )) ++ int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++
+    ));
+    const server_name_extension = int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++
         int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload
         int(u16, 1 + 2 + host_len) ++ // server_name_list byte count
         .{0x00} ++ // name_type
         int(u16, host_len);
+    const server_name_extension_len = switch (options.host) {
+        .no_verification => 0,
+        .explicit => server_name_extension.len + host_len,
+    };
 
     const extensions_header =
-        int(u16, @intCast(extensions_payload.len + host_len)) ++
-        extensions_payload;
+        int(u16, @intCast(extensions_payload.len + server_name_extension_len)) ++
+        extensions_payload ++
+        server_name_extension;
 
     const client_hello =
         int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
@@ -198,20 +253,24 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
         extensions_header;
 
     const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++
-        int(u24, @intCast(client_hello.len + host_len)) ++
+        int(u24, @intCast(client_hello.len - server_name_extension.len + server_name_extension_len)) ++
         client_hello;
 
-    const cleartext_header = .{@intFromEnum(tls.ContentType.handshake)} ++
+    const cleartext_header_buf = .{@intFromEnum(tls.ContentType.handshake)} ++
         int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++
-        int(u16, @intCast(out_handshake.len + host_len)) ++
+        int(u16, @intCast(out_handshake.len - server_name_extension.len + server_name_extension_len)) ++
         out_handshake;
+    const cleartext_header = switch (options.host) {
+        .no_verification => cleartext_header_buf[0 .. cleartext_header_buf.len - server_name_extension.len],
+        .explicit => &cleartext_header_buf,
+    };
 
     {
         var iovecs = [_]std.posix.iovec_const{
-            .{ .base = &cleartext_header, .len = cleartext_header.len },
+            .{ .base = cleartext_header.ptr, .len = cleartext_header.len },
             .{ .base = host.ptr, .len = host.len },
         };
-        try stream.writevAll(&iovecs);
+        try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
     }
 
     var tls_version: tls.ProtocolVersion = undefined;
@@ -472,6 +531,12 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                                         pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
                                         const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
                                         const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
+                                        if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                                            .client_random = &client_hello_rand,
+                                        }, .{
+                                            .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret,
+                                            .CLIENT_HANDSHAKE_TRAFFIC_SECRET = &client_secret,
+                                        });
                                         pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
                                         pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
                                         pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
@@ -544,6 +609,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                             const cert_size = certs_decoder.decode(u24);
                             const certd = try certs_decoder.sub(cert_size);
 
+                            if (tls_version == .tls_1_3) {
+                                try certs_decoder.ensure(2);
+                                const total_ext_size = certs_decoder.decode(u16);
+                                const all_extd = try certs_decoder.sub(total_ext_size);
+                                _ = all_extd;
+                            }
+
                             const subject_cert: Certificate = .{
                                 .buffer = certd.buf,
                                 .index = @intCast(certd.idx),
@@ -551,7 +623,10 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                             const subject = try subject_cert.parse();
                             if (cert_index == 0) {
                                 // Verify the host on the first certificate.
-                                try subject.verifyHostName(host);
+                                switch (options.host) {
+                                    .no_verification => {},
+                                    .explicit => try subject.verifyHostName(host),
+                                }
 
                                 // Keep track of the public key for the
                                 // certificate_verify message later.
@@ -560,23 +635,27 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                                 try prev_cert.verify(subject, now_sec);
                             }
 
-                            if (ca_bundle.verify(subject, now_sec)) |_| {
-                                handshake_state = .trust_chain_established;
-                                break :cert;
-                            } else |err| switch (err) {
-                                error.CertificateIssuerNotFound => {},
-                                else => |e| return e,
+                            switch (options.ca) {
+                                .no_verification => {
+                                    handshake_state = .trust_chain_established;
+                                    break :cert;
+                                },
+                                .self_signed => {
+                                    try subject.verify(subject, now_sec);
+                                    handshake_state = .trust_chain_established;
+                                    break :cert;
+                                },
+                                .bundle => |ca_bundle| if (ca_bundle.verify(subject, now_sec)) |_| {
+                                    handshake_state = .trust_chain_established;
+                                    break :cert;
+                                } else |err| switch (err) {
+                                    error.CertificateIssuerNotFound => {},
+                                    else => |e| return e,
+                                },
                             }
 
                             prev_cert = subject;
                             cert_index += 1;
-
-                            if (tls_version == .tls_1_3) {
-                                try certs_decoder.ensure(2);
-                                const total_ext_size = certs_decoder.decode(u16);
-                                const all_extd = try certs_decoder.sub(total_ext_size);
-                                _ = all_extd;
-                            }
                         }
                     },
                     .server_key_exchange => {
@@ -625,6 +704,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                                     &client_hello_rand,
                                     &server_hello_rand,
                                 }, 48);
+                                if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                                    .client_random = &client_hello_rand,
+                                }, .{
+                                    .CLIENT_RANDOM = &master_secret,
+                                });
                                 const key_block = hmacExpandLabel(
                                     P.Hmac,
                                     &master_secret,
@@ -748,6 +832,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
 
                                     const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
                                     const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+                                    if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+                                        .counter = key_seq,
+                                        .client_random = &client_hello_rand,
+                                    }, .{
+                                        .SERVER_TRAFFIC_SECRET = &server_secret,
+                                        .CLIENT_TRAFFIC_SECRET = &client_secret,
+                                    });
+                                    key_seq += 1;
                                     break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_3 = .{
                                         .client_secret = client_secret,
                                         .server_secret = server_secret,
@@ -784,8 +876,15 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
                             .partial_ciphertext_idx = 0,
                             .partial_ciphertext_end = @intCast(leftover.len),
                             .received_close_notify = false,
+                            .allow_truncation_attacks = false,
                             .application_cipher = app_cipher,
                             .partially_read_buffer = undefined,
+                            .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
+                                .client_key_seq = key_seq,
+                                .server_key_seq = key_seq,
+                                .client_random = client_hello_rand,
+                                .file = key_log_file,
+                            } else null,
                         };
                         @memcpy(client.partially_read_buffer[0..leftover.len], leftover);
                         return client;
@@ -1358,6 +1457,12 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
                                     const pv = &p.tls_1_3;
                                     const P = @TypeOf(p.*);
                                     const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
+                                    if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
+                                        .counter = key_log.serverCounter(),
+                                        .client_random = &key_log.client_random,
+                                    }, .{
+                                        .SERVER_TRAFFIC_SECRET = &server_secret,
+                                    });
                                     pv.server_secret = server_secret;
                                     pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
                                     pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
@@ -1372,6 +1477,12 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
                                             const pv = &p.tls_1_3;
                                             const P = @TypeOf(p.*);
                                             const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
+                                            if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
+                                                .counter = key_log.clientCounter(),
+                                                .client_random = &key_log.client_random,
+                                            }, .{
+                                                .CLIENT_TRAFFIC_SECRET = &client_secret,
+                                            });
                                             pv.client_secret = client_secret;
                                             pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
                                             pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
@@ -1426,6 +1537,18 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
     }
 }
 
+fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
+    const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
+    defer if (locked) key_log_file.unlock();
+    key_log_file.seekFromEnd(0) catch {};
+    inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++
+        (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {} {}\n", .{field.name} ++
+        (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
+        std.fmt.fmtSliceHexLower(context.client_random),
+        std.fmt.fmtSliceHexLower(@field(secrets, field.name)),
+    }) catch {};
+}
+
 fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
     const saved_buf = frag[in..];
     if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
lib/std/http/Client.zig
@@ -388,6 +388,7 @@ pub const Connection = struct {
 
             // try to cleanly close the TLS connection, for any server that cares.
             _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
+            if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close();
             allocator.destroy(conn.tls_client);
         }
 
@@ -566,7 +567,7 @@ pub const Response = struct {
             .reason = undefined,
             .version = undefined,
             .keep_alive = false,
-            .parser = proto.HeadersParser.init(&header_buffer),
+            .parser = .init(&header_buffer),
         };
 
         @memcpy(header_buffer[0..response_bytes.len], response_bytes);
@@ -610,7 +611,7 @@ pub const Response = struct {
     }
 
     pub fn iterateHeaders(r: Response) http.HeaderIterator {
-        return http.HeaderIterator.init(r.parser.get());
+        return .init(r.parser.get());
     }
 
     test iterateHeaders {
@@ -628,7 +629,7 @@ pub const Response = struct {
             .reason = undefined,
             .version = undefined,
             .keep_alive = false,
-            .parser = proto.HeadersParser.init(&header_buffer),
+            .parser = .init(&header_buffer),
         };
 
         @memcpy(header_buffer[0..response_bytes.len], response_bytes);
@@ -771,7 +772,7 @@ pub const Request = struct {
         req.client.connection_pool.release(req.client.allocator, req.connection.?);
         req.connection = null;
 
-        var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer);
+        var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer);
         defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..];
         const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
 
@@ -1354,7 +1355,21 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
         conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
         errdefer client.allocator.destroy(conn.data.tls_client);
 
-        conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
+        const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: {
+            const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) {
+                error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null,
+                error.OutOfMemory => return error.OutOfMemory,
+            };
+            defer client.allocator.free(ssl_key_log_path);
+            break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ .truncate = false }) catch null;
+        } else null;
+        errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close();
+
+        conn.data.tls_client.* = std.crypto.tls.Client.init(stream, .{
+            .host = .{ .explicit = host },
+            .ca = .{ .bundle = client.ca_bundle },
+            .ssl_key_log_file = ssl_key_log_file,
+        }) catch return error.TlsInitializationFailed;
         // This is appropriate for HTTPS because the HTTP headers contain
         // the content length which is used to detect truncation attacks.
         conn.data.tls_client.allow_truncation_attacks = true;
@@ -1620,7 +1635,7 @@ pub fn open(
         }
     }
 
-    var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer);
+    var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer);
     const protocol, const valid_uri = try validateUri(uri, server_header.allocator());
 
     if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) {
@@ -1654,7 +1669,7 @@ pub fn open(
             .status = undefined,
             .reason = undefined,
             .keep_alive = undefined,
-            .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]),
+            .parser = .init(server_header.buffer[server_header.end_index..]),
         },
         .headers = options.headers,
         .extra_headers = options.extra_headers,
lib/std/std.zig
@@ -146,6 +146,11 @@ pub const Options = struct {
     /// make a HTTPS connection.
     http_disable_tls: bool = false,
 
+    /// This enables `std.http.Client` to log ssl secrets to the file specified by the SSLKEYLOGFILE
+    /// env var.  Creating such a log file allows other programs with access to that file to decrypt
+    /// all `std.http.Client` traffic made by this program.
+    http_enable_ssl_key_log_file: bool = @import("builtin").mode == .Debug,
+
     side_channels_mitigations: crypto.SideChannelsMitigations = crypto.default_side_channels_mitigations,
 };