Commit 611a1fdd6d

Andrew Kelley <andrew@ziglang.org>
2022-12-31 04:06:42
std.crypto.tls: add API for sending close_notify
This commit adds `writeEnd` and `writeAllEnd` in order to send data and also notify the server that there will be no more data written. Unfortunately, it seems most TLS implementations in the wild get this wrong and immediately close the socket when they see a close_notify, rather than only ending the data stream on the application layer.
1 parent b3c8c38
Changed files (3)
lib
std
lib/std/crypto/tls/Client.zig
@@ -37,8 +37,54 @@ application_cipher: tls.ApplicationCipher,
 /// `partial_ciphertext_end` describe the span of the segments.
 partially_read_buffer: [tls.max_ciphertext_record_len]u8,
 
+/// This is an example of the type that is needed by the read and write
+/// functions. It can have any fields but it must at least have these
+/// functions.
+///
+/// Note that `std.net.Stream` conforms to this interface.
+///
+/// This declaration serves as documentation only.
+pub const StreamInterface = struct {
+    /// Can be any error set.
+    pub const ReadError = error{};
+
+    /// Returns the number of bytes read. The number read may be less than the
+    /// buffer space provided. End-of-stream is indicated by a return value of 0.
+    ///
+    /// The `iovecs` parameter is mutable because so that function may to
+    /// mutate the fields in order to handle partial reads from the underlying
+    /// stream layer.
+    pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+
+    /// Can be any error set.
+    pub const WriteError = error{};
+
+    /// Returns the number of bytes read, which may be less than the buffer
+    /// space provided. A short read does not indicate end-of-stream.
+    pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+
+    /// Returns the number of bytes read, which may be less than the buffer
+    /// space provided, indicating end-of-stream.
+    /// The `iovecs` parameter is mutable in case this function needs to mutate
+    /// the fields in order to handle partial writes from the underlying layer.
+    pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize {
+        // This can be implemented in terms of writev, or specialized if desired.
+        _ = .{ this, iovecs };
+        @panic("unimplemented");
+    }
+};
+
+/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which
+/// must conform to `StreamInterface`.
+///
 /// `host` is only borrowed during this function call.
-pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
+pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
     const host_len = @intCast(u16, host.len);
 
     var random_buffer: [128]u8 = undefined;
@@ -579,31 +625,115 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
     }
 }
 
-pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
+    return writeEnd(c, stream, bytes, false);
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
+    var index: usize = 0;
+    while (index < bytes.len) {
+        index += try c.write(stream, bytes[index..]);
+    }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
+    var index: usize = 0;
+    while (index < bytes.len) {
+        index += try c.writeEnd(stream, bytes[index..], end);
+    }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
     var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
+    var iovecs_buf: [6]std.os.iovec_const = undefined;
+    var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
+    if (end) {
+        prepared.iovec_end += prepareCiphertextRecord(
+            c,
+            iovecs_buf[prepared.iovec_end..],
+            ciphertext_buf[prepared.ciphertext_end..],
+            &tls.close_notify_alert,
+            .alert,
+        ).iovec_end;
+    }
+
+    const iovec_end = prepared.iovec_end;
+    const overhead_len = prepared.overhead_len;
+
+    // Ideally we would call writev exactly once here, however, we must ensure
+    // that we don't return with a record partially written.
+    var i: usize = 0;
+    var total_amt: usize = 0;
+    while (true) {
+        var amt = try stream.writev(iovecs_buf[i..iovec_end]);
+        while (amt >= iovecs_buf[i].iov_len) {
+            const encrypted_amt = iovecs_buf[i].iov_len;
+            total_amt += encrypted_amt - overhead_len;
+            amt -= encrypted_amt;
+            i += 1;
+            // Rely on the property that iovecs delineate records, meaning that
+            // if amt equals zero here, we have fortunately found ourselves
+            // with a short read that aligns at the record boundary.
+            if (i >= iovec_end) return total_amt;
+            // We also cannot return on a vector boundary if the final close_notify is
+            // not sent; otherwise the caller would not know to retry the call.
+            if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
+        }
+        iovecs_buf[i].iov_base += amt;
+        iovecs_buf[i].iov_len -= amt;
+    }
+}
+
+fn prepareCiphertextRecord(
+    c: *Client,
+    iovecs: []std.os.iovec_const,
+    ciphertext_buf: []u8,
+    bytes: []const u8,
+    inner_content_type: tls.ContentType,
+) struct {
+    iovec_end: usize,
+    ciphertext_end: usize,
+    /// How many bytes are taken up by overhead per record.
+    overhead_len: usize,
+} {
     // Due to the trailing inner content type byte in the ciphertext, we need
     // an additional buffer for storing the cleartext into before encrypting.
     var cleartext_buf: [max_ciphertext_len]u8 = undefined;
-    var iovecs_buf: [5]std.os.iovec_const = undefined;
     var ciphertext_end: usize = 0;
     var iovec_end: usize = 0;
     var bytes_i: usize = 0;
-    // How many bytes are taken up by overhead per record.
-    const overhead_len: usize = switch (c.application_cipher) {
-        inline else => |*p| l: {
+    switch (c.application_cipher) {
+        inline else => |*p| {
             const P = @TypeOf(p.*);
             const V = @Vector(P.AEAD.nonce_length, u8);
             const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
+            const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
             while (true) {
                 const encrypted_content_len = @intCast(u16, @min(
                     @min(bytes.len - bytes_i, max_ciphertext_len - 1),
-                    ciphertext_buf.len -
-                        tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
+                    ciphertext_buf.len - close_notify_alert_reserved -
+                        overhead_len - ciphertext_end,
                 ));
-                if (encrypted_content_len == 0) break :l overhead_len;
+                if (encrypted_content_len == 0) return .{
+                    .iovec_end = iovec_end,
+                    .ciphertext_end = ciphertext_end,
+                    .overhead_len = overhead_len,
+                };
 
                 mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
-                cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data);
+                cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type);
                 bytes_i += encrypted_content_len;
                 const ciphertext_len = encrypted_content_len + 1;
                 const cleartext = cleartext_buf[0..ciphertext_len];
@@ -626,40 +756,13 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
                 P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
 
                 const record = ciphertext_buf[record_start..ciphertext_end];
-                iovecs_buf[iovec_end] = .{
+                iovecs[iovec_end] = .{
                     .iov_base = record.ptr,
                     .iov_len = record.len,
                 };
                 iovec_end += 1;
             }
         },
-    };
-
-    // Ideally we would call writev exactly once here, however, we must ensure
-    // that we don't return with a record partially written.
-    var i: usize = 0;
-    var total_amt: usize = 0;
-    while (true) {
-        var amt = try stream.writev(iovecs_buf[i..iovec_end]);
-        while (amt >= iovecs_buf[i].iov_len) {
-            const encrypted_amt = iovecs_buf[i].iov_len;
-            total_amt += encrypted_amt - overhead_len;
-            amt -= encrypted_amt;
-            i += 1;
-            // Rely on the property that iovecs delineate records, meaning that
-            // if amt equals zero here, we have fortunately found ourselves
-            // with a short read that aligns at the record boundary.
-            if (i >= iovec_end or amt == 0) return total_amt;
-        }
-        iovecs_buf[i].iov_base += amt;
-        iovecs_buf[i].iov_len -= amt;
-    }
-}
-
-pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
-    var index: usize = 0;
-    while (index < bytes.len) {
-        index += try c.write(stream, bytes[index..]);
     }
 }
 
@@ -669,6 +772,7 @@ pub fn eof(c: Client) bool {
         c.partial_ciphertext_idx >= c.partial_ciphertext_end;
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 /// Returns the number of bytes read, calling the underlying read function the
 /// minimal number of times until the buffer has at least `len` bytes filled.
 /// If the number read is less than `len` it means the stream reached the end.
@@ -678,10 +782,12 @@ pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize
     return readvAtLeast(c, stream, &iovecs, len);
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
     return readAtLeast(c, stream, buffer, 1);
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 /// Returns the number of bytes read. If the number read is smaller than
 /// `buffer.len`, it means the stream reached the end. Reaching the end of the
 /// stream is not an error condition.
@@ -689,6 +795,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
     return readAtLeast(c, stream, buffer, buffer.len);
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 /// Returns the number of bytes read. If the number read is less than the space
 /// provided it means the stream reached the end. Reaching the end of the
 /// stream is not an error condition.
@@ -698,6 +805,7 @@ pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
     return readvAtLeast(c, stream, iovecs);
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 /// Returns the number of bytes read, calling the underlying read function the
 /// minimal number of times until the iovecs have at least `len` bytes filled.
 /// If the number read is less than `len` it means the stream reached the end.
@@ -722,6 +830,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us
     }
 }
 
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
 /// Returns number of bytes that have been read, populated inside `iovecs`. A
 /// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
 /// for the end of stream. The `eof()` may be true after any call to
@@ -729,7 +838,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us
 /// function asserts that `eof()` is `false`.
 /// See `readv` for a higher level function that has the same, familiar API as
 /// other read functions, such as `std.fs.File.read`.
-pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize {
+pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize {
     var vp: VecPut = .{ .iovecs = iovecs };
 
     // Give away the buffered cleartext we have, if any.
@@ -905,7 +1014,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
                         break :c cleartext;
                     },
                 };
-                c.read_seq += 1;
+
+                c.read_seq = try std.math.add(u64, c.read_seq, 1);
 
                 const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
                 switch (inner_ct) {
@@ -1196,3 +1306,7 @@ const cipher_suites = enum_array(tls.CipherSuite, &.{
     .AES_256_GCM_SHA384,
     .CHACHA20_POLY1305_SHA256,
 });
+
+test {
+    _ = StreamInterface;
+}
lib/std/crypto/tls.zig
@@ -47,6 +47,11 @@ pub const hello_retry_request_sequence = [32]u8{
     0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
 };
 
+pub const close_notify_alert = [_]u8{
+    @enumToInt(AlertLevel.warning),
+    @enumToInt(AlertDescription.close_notify),
+};
+
 pub const ProtocolVersion = enum(u16) {
     tls_1_2 = 0x0303,
     tls_1_3 = 0x0304,
lib/std/http/Client.zig
@@ -47,7 +47,7 @@ pub const Request = struct {
                 try req.stream.writeAll(req.headers.items);
             },
             .https => {
-                try req.tls_client.writeAll(req.stream, req.headers.items);
+                try req.tls_client.writeAllEnd(req.stream, req.headers.items, true);
             },
         }
     }