Commit 46b34949c3

Andrew Kelley <andrew@ziglang.org>
2025-08-07 07:39:26
TLS, HTTP, and package fetching fixes
* TLS: add missing assert for output buffer length requirement * TLS: add missing flushes * TLS: add flush implementation * TLS: finish drain implementation * HTTP: correct buffer sizes for TLS * HTTP: expose a getReadError method on Connection * HTTP: add missing flush on sendBodyComplete * Fetch: remove unwanted deinit * Fetch: improve error reporting
1 parent 172d31b
Changed files (3)
lib
std
crypto
http
src
Package
lib/std/crypto/tls/Client.zig
@@ -8,8 +8,8 @@ const mem = std.mem;
 const crypto = std.crypto;
 const assert = std.debug.assert;
 const Certificate = std.crypto.Certificate;
-const Reader = std.io.Reader;
-const Writer = std.io.Writer;
+const Reader = std.Io.Reader;
+const Writer = std.Io.Writer;
 
 const max_ciphertext_len = tls.max_ciphertext_len;
 const hmacExpandLabel = tls.hmacExpandLabel;
@@ -27,6 +27,8 @@ reader: Reader,
 
 /// The encrypted stream from the client to the server. Bytes are pushed here
 /// via `writer`.
+///
+/// The buffer is asserted to have capacity at least `min_buffer_len`.
 output: *Writer,
 /// The plaintext stream from the client to the server.
 writer: Writer,
@@ -122,7 +124,6 @@ pub const Options = struct {
     /// the amount of data expected, such as HTTP with the Content-Length header.
     allow_truncation_attacks: bool = false,
     write_buffer: []u8,
-    /// Asserted to have capacity at least `min_buffer_len`.
     read_buffer: []u8,
     /// Populated when `error.TlsAlert` is returned from `init`.
     alert: ?*tls.Alert = null,
@@ -185,6 +186,7 @@ const InitError = error{
 /// `input` is asserted to have buffer capacity at least `min_buffer_len`.
 pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
     assert(input.buffer.len >= min_buffer_len);
+    assert(output.buffer.len >= min_buffer_len);
     const host = switch (options.host) {
         .no_verification => "",
         .explicit => |host| host,
@@ -278,6 +280,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
     {
         var iovecs: [2][]const u8 = .{ cleartext_header, host };
         try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
+        try output.flush();
     }
 
     var tls_version: tls.ProtocolVersion = undefined;
@@ -763,6 +766,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                                     &client_verify_msg,
                                 };
                                 try output.writeVecAll(&all_msgs_vec);
+                                try output.flush();
                             },
                         }
                         write_seq += 1;
@@ -828,6 +832,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                                         &finished_msg,
                                     };
                                     try output.writeVecAll(&all_msgs_vec);
+                                    try output.flush();
 
                                     const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
                                     const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
@@ -877,7 +882,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
                                 .buffer = options.write_buffer,
                                 .vtable = &.{
                                     .drain = drain,
-                                    .sendFile = Writer.unimplementedSendFile,
+                                    .flush = flush,
                                 },
                             },
                             .tls_version = tls_version,
@@ -911,31 +916,56 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
 
 fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
     const c: *Client = @alignCast(@fieldParentPtr("writer", w));
-    if (true) @panic("update to use the buffer and flush");
-    const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
     const output = c.output;
     const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
-    var total_clear: usize = 0;
     var ciphertext_end: usize = 0;
-    for (sliced_data) |buf| {
-        const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
-        total_clear += prepared.cleartext_len;
-        ciphertext_end += prepared.ciphertext_end;
-        if (total_clear < buf.len) break;
+    var total_clear: usize = 0;
+    done: {
+        {
+            const buf = w.buffered();
+            const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+            total_clear += prepared.cleartext_len;
+            ciphertext_end += prepared.ciphertext_end;
+            if (prepared.cleartext_len < buf.len) break :done;
+        }
+        for (data[0 .. data.len - 1]) |buf| {
+            if (buf.len < min_buffer_len) break :done;
+            const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+            total_clear += prepared.cleartext_len;
+            ciphertext_end += prepared.ciphertext_end;
+            if (prepared.cleartext_len < buf.len) break :done;
+        }
+        const buf = data[data.len - 1];
+        for (0..splat) |_| {
+            if (buf.len < min_buffer_len) break :done;
+            const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+            total_clear += prepared.cleartext_len;
+            ciphertext_end += prepared.ciphertext_end;
+            if (prepared.cleartext_len < buf.len) break :done;
+        }
     }
     output.advance(ciphertext_end);
-    return total_clear;
+    return w.consume(total_clear);
+}
+
+fn flush(w: *Writer) Writer.Error!void {
+    const c: *Client = @alignCast(@fieldParentPtr("writer", w));
+    const output = c.output;
+    const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+    const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data);
+    output.advance(prepared.ciphertext_end);
+    w.end = 0;
 }
 
 /// Sends a `close_notify` alert, which is necessary for the server to
 /// distinguish between a properly finished TLS session, or a truncation
 /// attack.
 pub fn end(c: *Client) Writer.Error!void {
+    try flush(&c.writer);
     const output = c.output;
     const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
     const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
-    output.advance(prepared.cleartext_len);
-    return prepared.ciphertext_end;
+    output.advance(prepared.ciphertext_end);
 }
 
 fn prepareCiphertextRecord(
@@ -1045,7 +1075,7 @@ pub fn eof(c: Client) bool {
     return c.received_close_notify;
 }
 
-fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
+fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
     const c: *Client = @alignCast(@fieldParentPtr("reader", r));
     if (c.eof()) return error.EndOfStream;
     const input = c.input;
lib/std/http/Client.zig
@@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
 ///
 /// If the entire HTTP header cannot fit in this amount of bytes,
 /// `error.HttpHeadersOversize` will be returned from `Request.wait`.
-read_buffer_size: usize = 4096,
+read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
 /// Each `Connection` allocates this amount for the writer buffer.
 write_buffer_size: usize = 1024,
 
@@ -304,15 +304,16 @@ pub const Connection = struct {
             const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
             const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
             const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
-            const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
-            assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
+            const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
+            const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
+            assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
             @memcpy(host_buffer, remote_host);
             const tls: *Tls = @ptrCast(base);
             tls.* = .{
                 .connection = .{
                     .client = client,
-                    .stream_writer = stream.writer(socket_write_buffer),
-                    .stream_reader = stream.reader(&.{}),
+                    .stream_writer = stream.writer(tls_write_buffer),
+                    .stream_reader = stream.reader(tls_read_buffer),
                     .pool_node = .{},
                     .port = port,
                     .host_len = @intCast(remote_host.len),
@@ -328,8 +329,8 @@ pub const Connection = struct {
                         .host = .{ .explicit = remote_host },
                         .ca = .{ .bundle = client.ca_bundle },
                         .ssl_key_log = client.ssl_key_log,
-                        .read_buffer = tls_read_buffer,
-                        .write_buffer = tls_write_buffer,
+                        .read_buffer = read_buffer,
+                        .write_buffer = write_buffer,
                         // This is appropriate for HTTPS because the HTTP headers contain
                         // the content length which is used to detect truncation attacks.
                         .allow_truncation_attacks = true,
@@ -347,7 +348,8 @@ pub const Connection = struct {
         }
 
         fn allocLen(client: *Client, host_len: usize) usize {
-            return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size;
+            return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
+                client.write_buffer_size + client.read_buffer_size;
         }
 
         fn host(tls: *Tls) []u8 {
@@ -356,6 +358,21 @@ pub const Connection = struct {
         }
     };
 
+    pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;
+
+    pub fn getReadError(c: *const Connection) ?ReadError {
+        return switch (c.protocol) {
+            .tls => {
+                if (disable_tls) unreachable;
+                const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
+                return tls.client.read_err orelse c.stream_reader.getError();
+            },
+            .plain => {
+                return c.stream_reader.getError();
+            },
+        };
+    }
+
     fn getStream(c: *Connection) net.Stream {
         return c.stream_reader.getStream();
     }
@@ -434,7 +451,6 @@ pub const Connection = struct {
             if (disable_tls) unreachable;
             const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
             try tls.client.end();
-            try tls.client.writer.flush();
         }
         try c.stream_writer.interface.flush();
     }
@@ -874,6 +890,7 @@ pub const Request = struct {
         var bw = try sendBodyUnflushed(r, body);
         bw.writer.end = body.len;
         try bw.end();
+        try r.connection.?.flush();
     }
 
     /// Transfers the HTTP head over the connection, which is not flushed until
@@ -1063,6 +1080,9 @@ pub const Request = struct {
     /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize`
     /// is returned instead. This buffer may be empty if no redirects are to be
     /// handled.
+    ///
+    /// If this fails with `error.ReadFailed` then the `Connection.getReadError`
+    /// method of `r.connection` can be used to get more detailed information.
     pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response {
         var aux_buf = redirect_buffer;
         while (true) {
src/Package/Fetch.zig
@@ -998,15 +998,21 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
             .buffer = reader_buffer,
         } };
         const request = &resource.http_request.request;
-        defer request.deinit();
+        errdefer request.deinit();
 
         request.sendBodiless() catch |err|
             return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
 
         var redirect_buffer: [1024]u8 = undefined;
         const response = &resource.http_request.response;
-        response.* = request.receiveHead(&redirect_buffer) catch |err|
-            return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err}));
+        response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) {
+            error.ReadFailed => {
+                return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{
+                    request.connection.?.getReadError().?,
+                }));
+            },
+            else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})),
+        };
 
         if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString(
             "bad HTTP response code: '{d} {s}'",