Commit b97fc43baa

Andrew Kelley <andrew@ziglang.org>
2022-12-16 10:14:35
std.crypto.Tls: client is working against some servers
1 parent 40a8550
Changed files (2)
lib
std
crypto
http
lib/std/crypto/Tls.zig
@@ -9,12 +9,14 @@ application_cipher: ApplicationCipher,
 read_seq: u64,
 write_seq: u64,
 /// The size is enough to contain exactly one TLSCiphertext record.
-partially_read_buffer: [max_ciphertext_len + ciphertext_record_header_len]u8,
+partially_read_buffer: [max_ciphertext_record_len]u8,
 /// The number of partially read bytes inside `partiall_read_buffer`.
 partially_read_len: u15,
+eof: bool,
 
 pub const ciphertext_record_header_len = 5;
 pub const max_ciphertext_len = (1 << 14) + 256;
+pub const max_ciphertext_record_len = max_ciphertext_len + ciphertext_record_header_len;
 
 pub const ProtocolVersion = enum(u16) {
     tls_1_2 = 0x0303,
@@ -416,7 +418,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
 
     var cipher_params: CipherParams = undefined;
 
-    var handshake_buf: [4000]u8 = undefined;
+    var handshake_buf: [8000]u8 = undefined;
     var len: usize = 0;
     var i: usize = i: {
         const plaintext = handshake_buf[0..5];
@@ -554,8 +556,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
                         //    std.fmt.fmtSliceHexLower(&hello_hash),
                         //    std.fmt.fmtSliceHexLower(&early_secret),
                         //    std.fmt.fmtSliceHexLower(&empty_hash),
-                        //    std.fmt.fmtSliceHexLower(&derived_secret),
-                        //    std.fmt.fmtSliceHexLower(&handshake_secret),
+                        //    std.fmt.fmtSliceHexLower(&hs_derived_secret),
+                        //    std.fmt.fmtSliceHexLower(&p.handshake_secret),
                         //    std.fmt.fmtSliceHexLower(&client_secret),
                         //    std.fmt.fmtSliceHexLower(&server_secret),
                         //});
@@ -582,7 +584,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
         const end_hdr = i + 5;
         if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
         if (end_hdr > len) {
+            std.debug.print("read len={d} atleast={d}\n", .{ len, end_hdr - len });
             len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
+            std.debug.print("new len: {d} bytes\n", .{len});
             if (end_hdr > len) return error.EndOfStream;
         }
         const ct = @intToEnum(ContentType, handshake_buf[i]);
@@ -593,9 +597,12 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
         const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
         i += 2;
         const end = i + record_size;
+        std.debug.print("ct={any} record_size={d} end={d}\n", .{ ct, record_size, end });
         if (end > handshake_buf.len) return error.TlsRecordOverflow;
         if (end > len) {
+            std.debug.print("read len={d} atleast={d}\n", .{ len, end - len });
             len += try stream.readAtLeast(handshake_buf[len..], end - len);
+            std.debug.print("new len: {d} bytes\n", .{len});
             if (end > len) return error.EndOfStream;
         }
         switch (ct) {
@@ -604,7 +611,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
                 if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
             },
             .application_data => {
-                var cleartext_buf: [1000]u8 = undefined;
+                var cleartext_buf: [8000]u8 = undefined;
                 const cleartext = switch (cipher_params) {
                     inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
                         const P = @TypeOf(p.*);
@@ -637,17 +644,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
                 };
 
                 const inner_ct = cleartext[cleartext.len - 1];
+                std.debug.print("inner_ct={any}\n", .{@intToEnum(ContentType, inner_ct)});
                 switch (inner_ct) {
                     @enumToInt(ContentType.handshake) => {
                         const handshake_len = mem.readIntBig(u24, cleartext[1..4]);
-                        if (4 + handshake_len != cleartext.len - 1) return error.TlsBadLength;
+                        if (4 + handshake_len > cleartext.len - 1) return error.TlsBadLength;
+                        std.debug.print("handshake type: {any} size: {d}\n", .{ @intToEnum(HandshakeType, cleartext[0]), handshake_len });
                         switch (cleartext[0]) {
                             @enumToInt(HandshakeType.encrypted_extensions) => {
                                 const ext_size = mem.readIntBig(u16, cleartext[4..6]);
-                                if (ext_size != 0) {
-                                    @panic("TODO handle encrypted extensions");
-                                }
-                                std.debug.print("empty encrypted extensions\n", .{});
+                                std.debug.print("{d} bytes of encrypted extensions\n", .{
+                                    ext_size,
+                                });
                             },
                             @enumToInt(HandshakeType.certificate) => {
                                 std.debug.print("cool certificate bro\n", .{});
@@ -688,22 +696,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
                                         const nonce = p.client_handshake_iv;
                                         P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);
 
-                                        {
-                                            var iovecs = [_]std.os.iovec_const{
-                                                .{
-                                                    .iov_base = &client_change_cipher_spec_msg,
-                                                    .iov_len = client_change_cipher_spec_msg.len,
-                                                },
-                                                .{
-                                                    .iov_base = &finished_msg,
-                                                    .iov_len = finished_msg.len,
-                                                },
-                                            };
-                                            try stream.writevAll(&iovecs);
-                                        }
+                                        //const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
+                                        _ = client_change_cipher_spec_msg;
+                                        const both_msgs = finished_msg;
+                                        try stream.writeAll(&both_msgs);
 
                                         const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
                                         const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
+                                        //std.debug.print("master_secret={}\nclient_secret={}\nserver_secret={}\n", .{
+                                        //    std.fmt.fmtSliceHexLower(&p.master_secret),
+                                        //    std.fmt.fmtSliceHexLower(&client_secret),
+                                        //    std.fmt.fmtSliceHexLower(&server_secret),
+                                        //});
                                         break :c @unionInit(ApplicationCipher, @tagName(tag), .{
                                             .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
                                             .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
@@ -721,12 +725,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
                                         @panic("TODO");
                                     },
                                 };
+                                std.debug.print("remaining bytes: {d}\n", .{len - end});
                                 return .{
                                     .application_cipher = app_cipher,
-                                    .read_seq = read_seq,
-                                    .write_seq = 1,
+                                    .read_seq = 0,
+                                    .write_seq = 0,
                                     .partially_read_buffer = undefined,
                                     .partially_read_len = 0,
+                                    .eof = false,
                                 };
                             },
                             else => {
@@ -753,49 +759,67 @@ pub fn init(stream: net.Stream, host: []const u8) !Tls {
 }
 
 pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
-    var ciphertext_buf: [max_ciphertext_len * 4]u8 = undefined;
+    var ciphertext_buf: [max_ciphertext_record_len * 4]u8 = undefined;
+    // Due to the trailing inner content type byte in the ciphertext, we need
+    // an additional buffer for storing the cleartext into before encrypting.
+    var cleartext_buf: [max_ciphertext_len]u8 = undefined;
     var iovecs_buf: [5]std.os.iovec_const = undefined;
     var ciphertext_end: usize = 0;
     var iovec_end: usize = 0;
     var bytes_i: usize = 0;
-    switch (tls.application_cipher) {
-        inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| {
+    // How many bytes are taken up by overhead per record.
+    const overhead_len: usize = switch (tls.application_cipher) {
+        inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
             const P = @TypeOf(p.*);
             const V = @Vector(P.AEAD.nonce_length, u8);
+            const overhead_len = ciphertext_record_header_len + P.AEAD.tag_length + 1;
             while (true) {
-                const ciphertext_len = @intCast(u16, @min(
-                    @min(bytes.len - bytes_i, max_ciphertext_len),
-                    ciphertext_buf.len - 5 - P.AEAD.tag_length - ciphertext_end,
+                const encrypted_content_len = @intCast(u16, @min(
+                    @min(bytes.len - bytes_i, max_ciphertext_len - 1),
+                    ciphertext_buf.len -
+                        ciphertext_record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
                 ));
-                if (ciphertext_len == 0) return bytes_i;
+                if (encrypted_content_len == 0) break :l overhead_len;
 
-                const wrapped_len = ciphertext_len + P.AEAD.tag_length;
-                const record = ciphertext_buf[ciphertext_end..][0 .. 5 + wrapped_len];
+                mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
+                cleartext_buf[encrypted_content_len] = @enumToInt(ContentType.application_data);
+                bytes_i += encrypted_content_len;
+                const ciphertext_len = encrypted_content_len + 1;
+                const cleartext = cleartext_buf[0..ciphertext_len];
 
-                const ad = record[0..5];
-                ciphertext_end += 5;
+                const record_start = ciphertext_end;
+                const ad = ciphertext_buf[ciphertext_end..][0..5];
+                ad.* =
+                    [_]u8{@enumToInt(ContentType.application_data)} ++
+                    int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
+                    int2(ciphertext_len + P.AEAD.tag_length);
+                ciphertext_end += ad.len;
                 const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
                 ciphertext_end += ciphertext_len;
                 const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
-                ciphertext_end += P.AEAD.tag_length;
+                ciphertext_end += auth_tag.len;
                 const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                 const operand: V = pad ++ @bitCast([8]u8, big(tls.write_seq));
                 tls.write_seq += 1;
                 const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.client_iv) ^ operand;
-                ad.* =
-                    [_]u8{@enumToInt(ContentType.application_data)} ++
-                    int2(@enumToInt(ProtocolVersion.tls_1_2)) ++
-                    int2(wrapped_len);
-                const cleartext = bytes[bytes_i..ciphertext.len];
                 P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
-
+                //std.debug.print("seq: {d} nonce: {} client_key: {} client_iv: {} ad: {} auth_tag: {}\nserver_key: {} server_iv: {}", .{
+                //    tls.write_seq - 1,
+                //    std.fmt.fmtSliceHexLower(&nonce),
+                //    std.fmt.fmtSliceHexLower(&p.client_key),
+                //    std.fmt.fmtSliceHexLower(&p.client_iv),
+                //    std.fmt.fmtSliceHexLower(ad),
+                //    std.fmt.fmtSliceHexLower(auth_tag),
+                //    std.fmt.fmtSliceHexLower(&p.server_key),
+                //    std.fmt.fmtSliceHexLower(&p.server_iv),
+                //});
+
+                const record = ciphertext_buf[record_start..ciphertext_end];
                 iovecs_buf[iovec_end] = .{
                     .iov_base = record.ptr,
                     .iov_len = record.len,
                 };
                 iovec_end += 1;
-
-                bytes_i += ciphertext_len;
             }
         },
         .TLS_CHACHA20_POLY1305_SHA256 => {
@@ -807,7 +831,7 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
         .TLS_AES_128_CCM_8_SHA256 => {
             @panic("TODO");
         },
-    }
+    };
 
     // Ideally we would call writev exactly once here, however, we must ensure
     // that we don't return with a record partially written.
@@ -815,9 +839,10 @@ pub fn write(tls: *Tls, stream: net.Stream, bytes: []const u8) !usize {
     var total_amt: usize = 0;
     while (true) {
         var amt = try stream.writev(iovecs_buf[i..iovec_end]);
-        total_amt += amt;
         while (amt >= iovecs_buf[i].iov_len) {
-            amt -= iovecs_buf[i].iov_len;
+            const encrypted_amt = iovecs_buf[i].iov_len;
+            total_amt += encrypted_amt - overhead_len;
+            amt -= encrypted_amt;
             i += 1;
             // Rely on the property that iovecs delineate records, meaning that
             // if amt equals zero here, we have fortunately found ourselves
@@ -849,11 +874,17 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
     const wanted_read_len = buf_cap * (max_ciphertext_len + ciphertext_record_header_len);
     const actual_read_len = try stream.read(in_buf[prev_len..@min(wanted_read_len, in_buf.len)]);
     const frag = in_buf[0 .. prev_len + actual_read_len];
+    if (frag.len == 0) {
+        tls.eof = true;
+        return 0;
+    }
+    std.debug.print("actual_read_len={d} frag.len={d}\n", .{ actual_read_len, frag.len });
     var in: usize = 0;
     var out: usize = 0;
 
     while (true) {
         if (in + ciphertext_record_header_len > frag.len) {
+            std.debug.print("in={d} frag.len={d}\n", .{ in, frag.len });
             return finishRead(tls, frag, in, out);
         }
         const ct = @intToEnum(ContentType, frag[in]);
@@ -866,6 +897,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
         const end = in + record_size;
         if (end > frag.len) {
             if (record_size > max_ciphertext_len) return error.TlsRecordOverflow;
+            std.debug.print("end={d} frag.len={d}\n", .{ end, frag.len });
             return finishRead(tls, frag, in, out);
         }
         switch (ct) {
@@ -877,6 +909,7 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
                     inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
                         const P = @TypeOf(p.*);
                         const V = @Vector(P.AEAD.nonce_length, u8);
+                        const ad = frag[in - 5 ..][0..5];
                         const ciphertext_len = record_size - P.AEAD.tag_length;
                         const ciphertext = frag[in..][0..ciphertext_len];
                         in += ciphertext_len;
@@ -886,7 +919,12 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
                         const operand: V = pad ++ @bitCast([8]u8, big(tls.read_seq));
                         tls.read_seq += 1;
                         const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
-                        const ad = frag[0..ciphertext_record_header_len];
+                        //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
+                        //    tls.read_seq - 1,
+                        //    std.fmt.fmtSliceHexLower(&nonce),
+                        //    std.fmt.fmtSliceHexLower(&p.server_key),
+                        //    std.fmt.fmtSliceHexLower(&p.server_iv),
+                        //});
                         P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
                             return error.TlsBadRecordMac;
                         break :c cleartext.len;
@@ -902,15 +940,26 @@ pub fn read(tls: *Tls, stream: net.Stream, buffer: []u8) !usize {
                     },
                 };
 
-                const inner_ct = buffer[out + cleartext_len - 1];
+                const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
                 switch (inner_ct) {
-                    @enumToInt(ContentType.handshake) => {
+                    .alert => {
+                        const level = @intToEnum(AlertLevel, buffer[out]);
+                        const desc = @intToEnum(AlertDescription, buffer[out + 1]);
+                        if (desc == .close_notify) {
+                            tls.eof = true;
+                            return out;
+                        }
+                        std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+                        return error.TlsAlert;
+                    },
+                    .handshake => {
                         std.debug.print("the server wants to keep shaking hands\n", .{});
                     },
-                    @enumToInt(ContentType.application_data) => {
+                    .application_data => {
                         out += cleartext_len - 1;
                     },
                     else => {
+                        std.debug.print("inner content type: {d}\n", .{inner_ct});
                         return error.TlsUnexpectedMessage;
                     },
                 }
lib/std/http/Client.zig
@@ -62,6 +62,25 @@ pub const Request = struct {
             .https => return req.tls.read(req.stream, buffer),
         }
     }
+
+    pub fn readAll(req: *Request, buffer: []u8) !usize {
+        return readAtLeast(req, buffer, buffer.len);
+    }
+
+    pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
+        var index: usize = 0;
+        while (index < len) {
+            const amt = try req.read(buffer[index..]);
+            if (amt == 0) {
+                switch (req.protocol) {
+                    .http => break,
+                    .https => if (req.tls.eof) break,
+                }
+            }
+            index += amt;
+        }
+        return index;
+    }
 };
 
 pub fn deinit(client: *Client) void {
@@ -92,7 +111,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
         @tagName(options.method).len +
             1 +
             options.path.len +
-            " HTTP/2\r\nHost: ".len +
+            " HTTP/1.1\r\nHost: ".len +
             options.host.len +
             "\r\nUpgrade-Insecure-Requests: 1\r\n".len +
             client.headers.items.len +
@@ -101,7 +120,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
     req.headers.appendSliceAssumeCapacity(@tagName(options.method));
     req.headers.appendSliceAssumeCapacity(" ");
     req.headers.appendSliceAssumeCapacity(options.path);
-    req.headers.appendSliceAssumeCapacity(" HTTP/2\r\nHost: ");
+    req.headers.appendSliceAssumeCapacity(" HTTP/1.1\r\nHost: ");
     req.headers.appendSliceAssumeCapacity(options.host);
     switch (options.protocol) {
         .https => req.headers.appendSliceAssumeCapacity("\r\nUpgrade-Insecure-Requests: 1\r\n"),