Commit 0eef21d8ec

Nameless <truemedian@gmail.com>
2023-10-05 19:19:06
std.http.Client: add option to disable https
std_options.http_connection_pool_size removed in favor of ``` client.connection_pool.resize(client.allocator, size); ``` std_options.http_disable_tls will remove all https capability from std.http when true. Any https request will error with `error.TlsInitializationFailed`. Solves #17051.
1 parent e1c37f7
Changed files (3)
lib
test
standalone
lib/std/http/Client.zig
@@ -12,8 +12,7 @@ const assert = std.debug.assert;
 const Client = @This();
 const proto = @import("protocol.zig");
 
-pub const default_connection_pool_size = 32;
-pub const connection_pool_size = std.options.http_connection_pool_size;
+pub const disable_tls = std.options.http_disable_tls;
 
 allocator: Allocator,
 ca_bundle: std.crypto.Certificate.Bundle = .{},
@@ -50,7 +49,7 @@ pub const ConnectionPool = struct {
     /// Open connections that are not currently in use.
     free: Queue = .{},
     free_len: usize = 0,
-    free_size: usize = connection_pool_size,
+    free_size: usize = 32,
 
     /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
     /// If no connection is found, null is returned.
@@ -127,23 +126,43 @@ pub const ConnectionPool = struct {
         pool.used.append(node);
     }
 
-    pub fn deinit(pool: *ConnectionPool, client: *Client) void {
+    /// Resizes the connection pool. This function is threadsafe.
+    ///
+    /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size.
+    pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void {
+        pool.mutex.lock();
+        defer pool.mutex.unlock();
+
+        var next = pool.free.first;
+        _ = next;
+        while (pool.free_len > new_size) {
+            const popped = pool.free.popFirst() orelse unreachable;
+            pool.free_len -= 1;
+
+            popped.data.close(allocator);
+            allocator.destroy(popped);
+        }
+
+        pool.free_size = new_size;
+    }
+
+    pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void {
         pool.mutex.lock();
 
         var next = pool.free.first;
         while (next) |node| {
-            defer client.allocator.destroy(node);
+            defer allocator.destroy(node);
             next = node.next;
 
-            node.data.close(client.allocator);
+            node.data.close(allocator);
         }
 
         next = pool.used.first;
         while (next) |node| {
-            defer client.allocator.destroy(node);
+            defer allocator.destroy(node);
             next = node.next;
 
-            node.data.close(client.allocator);
+            node.data.close(allocator);
         }
 
         pool.* = undefined;
@@ -159,7 +178,7 @@ pub const Connection = struct {
 
     stream: net.Stream,
     /// undefined unless protocol is tls.
-    tls_client: *std.crypto.tls.Client,
+    tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
 
     protocol: Protocol,
     host: []u8,
@@ -174,11 +193,8 @@ pub const Connection = struct {
     read_buf: [buffer_size]u8 = undefined,
     write_buf: [buffer_size]u8 = undefined,
 
-    pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
-        return switch (conn.protocol) {
-            .plain => conn.stream.readv(buffers),
-            .tls => conn.tls_client.readv(conn.stream, buffers),
-        } catch |err| {
+    pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
+        return conn.tls_client.readv(conn.stream, buffers) catch |err| {
             // TODO: https://github.com/ziglang/zig/issues/2473
             if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;
 
@@ -191,6 +207,20 @@ pub const Connection = struct {
         };
     }
 
+    pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
+        if (conn.protocol == .tls) {
+            if (disable_tls) unreachable;
+
+            return conn.readvDirectTls(buffers);
+        }
+
+        return conn.stream.readv(buffers) catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
+    }
+
     pub fn fill(conn: *Connection) ReadError!void {
         if (conn.read_end != conn.read_start) return;
 
@@ -257,11 +287,21 @@ pub const Connection = struct {
         return Reader{ .context = conn };
     }
 
+    pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
+        return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
+    }
+
     pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
-        return switch (conn.protocol) {
-            .plain => conn.stream.writeAll(buffer),
-            .tls => conn.tls_client.writeAll(conn.stream, buffer),
-        } catch |err| switch (err) {
+        if (conn.protocol == .tls) {
+            if (disable_tls) unreachable;
+
+            return conn.writeAllDirectTls(buffer);
+        }
+
+        return conn.stream.writeAll(buffer) catch |err| switch (err) {
             error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
             else => return error.UnexpectedWriteFailure,
         };
@@ -303,6 +343,8 @@ pub const Connection = struct {
 
     pub fn close(conn: *Connection, allocator: Allocator) void {
         if (conn.protocol == .tls) {
+            if (disable_tls) unreachable;
+
             // try to cleanly close the TLS connection, for any server that cares.
             _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
             allocator.destroy(conn.tls_client);
@@ -932,7 +974,7 @@ pub const ProxyInformation = struct {
 /// Release all associated resources with the client.
 /// TODO: currently leaks all request allocated data
 pub fn deinit(client: *Client) void {
-    client.connection_pool.deinit(client);
+    client.connection_pool.deinit(client.allocator);
 
     if (client.http_proxy) |*proxy| {
         proxy.allocator.free(proxy.host);
@@ -1046,6 +1088,9 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
     })) |node|
         return node;
 
+    if (disable_tls and protocol == .tls)
+        return error.TlsInitializationFailed;
+
     const conn = try client.allocator.create(ConnectionPool.Node);
     errdefer client.allocator.destroy(conn);
     conn.* = .{ .data = undefined };
@@ -1073,17 +1118,16 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
     };
     errdefer client.allocator.free(conn.data.host);
 
-    switch (protocol) {
-        .plain => {},
-        .tls => {
-            conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
-            errdefer client.allocator.destroy(conn.data.tls_client);
+    if (protocol == .tls) {
+        if (disable_tls) unreachable;
 
-            conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
-            // This is appropriate for HTTPS because the HTTP headers contain
-            // the content length which is used to detect truncation attacks.
-            conn.data.tls_client.allow_truncation_attacks = true;
-        },
+        conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
+        errdefer client.allocator.destroy(conn.data.tls_client);
+
+        conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
+        // This is appropriate for HTTPS because the HTTP headers contain
+        // the content length which is used to detect truncation attacks.
+        conn.data.tls_client.allow_truncation_attacks = true;
     }
 
     client.connection_pool.addUsed(conn);
lib/std/std.zig
@@ -283,10 +283,15 @@ pub const options = struct {
     else
         false;
 
-    pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size"))
-        options_override.http_connection_pool_size
+    /// By default, std.http.Client will support HTTPS connections.  Set this option to `true` to
+    /// disable TLS support.
+    /// 
+    /// This will likely reduce the size of the binary, but it will also make it impossible to
+    /// make a HTTPS connection.
+    pub const http_disable_tls = if (@hasDecl(options_override, "http_disable_tls"))
+        options_override.http_disable_tls
     else
-        http.Client.default_connection_pool_size;
+        false;
 
     pub const side_channels_mitigations: crypto.SideChannelsMitigations = if (@hasDecl(options_override, "side_channels_mitigations"))
         options_override.side_channels_mitigations
test/standalone/http.zig
@@ -7,6 +7,10 @@ const Client = http.Client;
 const mem = std.mem;
 const testing = std.testing;
 
+pub const std_options = struct {
+    pub const http_disable_tls = true;
+};
+
 const max_header_size = 8192;
 
 var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){};