Commit 96533b1289

Nameless <truemedian@gmail.com>
2023-04-14 19:38:13
std.http: very basic http client proxy
1 parent 2c49206
Changed files (4)
lib/std/http/Client.zig
@@ -25,27 +25,7 @@ next_https_rescan_certs: bool = true,
 /// The pool of connections that can be reused (and currently in use).
 connection_pool: ConnectionPool = .{},
 
-pub const ExtraError = union(enum) {
-    pub const TcpConnectError = std.net.TcpConnectToHostError;
-    pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
-    pub const WriteError = BufferedConnection.WriteError;
-    pub const ReadError = BufferedConnection.ReadError || error{HttpChunkInvalid};
-    pub const CaBundleError = std.crypto.Certificate.Bundle.RescanError;
-
-    pub const ZlibInitError = error{ BadHeader, InvalidCompression, InvalidWindowSize, Unsupported, EndOfStream, OutOfMemory } || Request.TransferReadError;
-    pub const GzipInitError = error{ BadHeader, InvalidCompression, OutOfMemory, WrongChecksum, EndOfStream, StreamTooLong } || Request.TransferReadError;
-    // pub const DecompressError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error;
-    pub const DecompressError = anyerror; // FIXME: the above line causes a false positive dependency loop
-
-    zlib_init: ZlibInitError, // error.CompressionInitializationFailed
-    gzip_init: GzipInitError, // error.CompressionInitializationFailed
-    connect: TcpConnectError, // error.ConnectionFailed
-    ca_bundle: CaBundleError, // error.CertificateAuthorityBundleFailed
-    tls: TlsError, // error.TlsInitializationFailed
-    write: WriteError, // error.WriteFailed
-    read: ReadError, // error.ReadFailed
-    decompress: DecompressError, // error.ReadFailed
-};
+proxy: ?HttpProxy = null,
 
 /// A set of linked lists of connections that can be reused.
 pub const ConnectionPool = struct {
@@ -61,6 +41,7 @@ pub const ConnectionPool = struct {
         host: []u8,
         port: u16,
 
+        proxied: bool = false,
         closing: bool = false,
 
         pub fn deinit(self: *StoredConnection, client: *Client) void {
@@ -137,7 +118,12 @@ pub const ConnectionPool = struct {
             return client.allocator.destroy(popped);
         }
 
-        pool.free.append(node);
+        if (node.data.proxied) {
+            pool.free.prepend(node); // proxied connections go to the end of the queue, always try direct connections first
+        } else {
+            pool.free.append(node);
+        }
+
         pool.free_len += 1;
     }
 
@@ -546,9 +532,10 @@ pub const Request = struct {
         if (!req.response.parser.done) {
             // If the response wasn't fully read, then we need to close the connection.
             req.connection.data.closing = true;
-            req.client.connection_pool.release(req.client, req.connection);
         }
 
+        req.client.connection_pool.release(req.client, req.connection);
+
         req.arena.deinit();
         req.* = undefined;
     }
@@ -557,30 +544,20 @@ pub const Request = struct {
         var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
         const w = buffered.writer();
 
-        const escaped_path = try Uri.escapePath(req.client.allocator, uri.path);
-        defer req.client.allocator.free(escaped_path);
-
-        const escaped_query = if (uri.query) |q| try Uri.escapeQuery(req.client.allocator, q) else null;
-        defer if (escaped_query) |q| req.client.allocator.free(q);
-
-        const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(req.client.allocator, f) else null;
-        defer if (escaped_fragment) |f| req.client.allocator.free(f);
-
         try w.writeAll(@tagName(headers.method));
         try w.writeByte(' ');
-        if (escaped_path.len == 0) {
-            try w.writeByte('/');
+
+        if (req.headers.method == .CONNECT) {
+            try w.writeAll(uri.host.?);
+            try w.writeByte(':');
+            try w.print("{}", .{uri.port.?});
+        } else if (req.connection.data.proxied) {
+            // proxied connections require the full uri
+            try w.print("{+/}", .{uri});
         } else {
-            try w.writeAll(escaped_path);
-        }
-        if (escaped_query) |q| {
-            try w.writeByte('?');
-            try w.writeAll(q);
-        }
-        if (escaped_fragment) |f| {
-            try w.writeByte('#');
-            try w.writeAll(f);
+            try w.print("{/}", .{uri});
         }
+
         try w.writeByte(' ');
         try w.writeAll(@tagName(headers.version));
         try w.writeAll("\r\nHost: ");
@@ -659,6 +636,12 @@ pub const Request = struct {
                 req.response.parser.done = true;
             }
 
+            if (req.headers.method == .CONNECT and req.response.headers.status == .ok) {
+                req.connection.data.closing = false;
+                req.connection.data.proxied = true;
+                req.response.parser.done = true;
+            }
+
             if (req.headers.connection == .keep_alive and req.response.headers.connection == .keep_alive) {
                 req.connection.data.closing = false;
             } else {
@@ -802,7 +785,7 @@ pub const Request = struct {
         }
     }
 
-    pub const FinishError = WriteError || error{ MessageNotCompleted };
+    pub const FinishError = WriteError || error{MessageNotCompleted};
 
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
@@ -817,6 +800,20 @@ pub const Request = struct {
     }
 };
 
+pub const HttpProxy = struct {
+    pub const ProxyAuthentication = union(enum) {
+        basic: []const u8,
+        custom: []const u8,
+    };
+
+    protocol: Connection.Protocol,
+    host: []const u8,
+    port: ?u16 = null,
+
+    /// The value for the Proxy-Authorization header.
+    auth: ?ProxyAuthentication = null,
+};
+
 /// Release all associated resources with the client.
 /// TODO: currently leaks all request allocated data
 pub fn deinit(client: *Client) void {
@@ -826,11 +823,11 @@ pub fn deinit(client: *Client) void {
     client.* = undefined;
 }
 
-pub const ConnectError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
+pub const ConnectUnproxiedError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
 
 /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
 /// This function is threadsafe.
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectUnproxiedError!*ConnectionPool.Node {
     if (client.connection_pool.findConnection(.{
         .host = host,
         .port = port,
@@ -884,7 +881,34 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     return conn;
 }
 
-pub const RequestError = ConnectError || BufferedConnection.WriteError || error{
+// Prevents a dependency loop in request()
+const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
+pub const ConnectError = ConnectErrorPartial || RequestError;
+
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
+    if (client.connection_pool.findConnection(.{
+        .host = host,
+        .port = port,
+        .is_tls = protocol == .tls,
+    })) |node|
+        return node;
+
+    if (client.proxy) |proxy| {
+        const proxy_port: u16 = proxy.port orelse switch (proxy.protocol) {
+            .plain => 80,
+            .tls => 443,
+        };
+
+        const conn = try client.connectUnproxied(proxy.host, proxy_port, proxy.protocol);
+        conn.data.proxied = true;
+
+        return conn;
+    } else {
+        return client.connectUnproxied(host, port, protocol);
+    }
+}
+
+pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || BufferedConnection.WriteError || error{
     UnsupportedUrlScheme,
     UriMissingHost,
 
@@ -896,6 +920,9 @@ pub const Options = struct {
     max_redirects: u32 = 3,
     header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
 
+    /// Must be an already acquired connection.
+    connection: ?*ConnectionPool.Node = null,
+
     pub const HeaderStrategy = union(enum) {
         /// In this case, the client's Allocator will be used to store the
         /// entire HTTP header. This value is the maximum total size of
@@ -939,10 +966,12 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
         }
     }
 
+    const conn = options.connection orelse try client.connect(host, port, protocol);
+
     var req: Request = .{
         .uri = uri,
         .client = client,
-        .connection = try client.connect(host, port, protocol),
+        .connection = conn,
         .headers = headers,
         .redirects_left = options.max_redirects,
         .handle_redirects = options.handle_redirects,
lib/std/http/protocol.zig
@@ -1,4 +1,4 @@
-const std = @import("std");
+const std = @import("../std.zig");
 const testing = std.testing;
 const mem = std.mem;
 
lib/std/http.zig
@@ -265,7 +265,7 @@ pub const Connection = enum {
     close,
 };
 
-pub const CustomHeader = struct {
+pub const Header = struct {
     name: []const u8,
     value: []const u8,
 };
lib/std/Uri.zig
@@ -27,6 +27,18 @@ pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfM
     return escapeStringWithFn(allocator, input, isQueryChar);
 }
 
+pub fn writeEscapedString(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isUnreserved);
+}
+
+pub fn writeEscapedPath(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isPathChar);
+}
+
+pub fn writeEscapedQuery(writer: anytype, input: []const u8) !void {
+    return writeEscapedStringWithFn(writer, input, isQueryChar);
+}
+
 pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
     var outsize: usize = 0;
     for (input) |c| {
@@ -52,6 +64,16 @@ pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, compt
     return output;
 }
 
+pub fn writeEscapedStringWithFn(writer: anytype, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) @TypeOf(writer).Error!void {
+    for (input) |c| {
+        if (keepUnescaped(c)) {
+            try writer.writeByte(c);
+        } else {
+            try writer.print("%{X:0>2}", .{c});
+        }
+    }
+}
+
 /// Parses a URI string and unescapes all %XX where XX is a valid hex number. Otherwise, verbatim copies
 /// them to the output.
 pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
@@ -184,6 +206,60 @@ pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
     return uri;
 }
 
+pub fn format(
+    uri: Uri,
+    comptime fmt: []const u8,
+    options: std.fmt.FormatOptions,
+    writer: anytype,
+) @TypeOf(writer).Error!void {
+    _ = options;
+
+    const needs_absolute = comptime std.mem.indexOf(u8, fmt, "+") != null;
+    const needs_path = comptime std.mem.indexOf(u8, fmt, "/") != null or fmt.len == 0;
+
+    if (needs_absolute) {
+        try writer.writeAll(uri.scheme);
+        try writer.writeAll(":");
+        if (uri.host) |host| {
+            try writer.writeAll("//");
+
+            if (uri.user) |user| {
+                try writer.writeAll(user);
+                if (uri.password) |password| {
+                    try writer.writeAll(":");
+                    try writer.writeAll(password);
+                }
+                try writer.writeAll("@");
+            }
+
+            try writer.writeAll(host);
+
+            if (uri.port) |port| {
+                try writer.writeAll(":");
+                try std.fmt.formatInt(port, 10, .lower, .{}, writer);
+            }
+        }
+    }
+
+    if (needs_path) {
+        if (uri.path.len == 0) {
+            try writer.writeAll("/");
+        } else {
+            try Uri.writeEscapedPath(writer, uri.path);
+        }
+
+        if (uri.query) |q| {
+            try writer.writeAll("?");
+            try Uri.writeEscapedQuery(writer, q);
+        }
+
+        if (uri.fragment) |f| {
+            try writer.writeAll("#");
+            try Uri.writeEscapedQuery(writer, f);
+        }
+    }
+}
+
 /// Parses the URI or returns an error.
 /// The return value will contain unescaped strings pointing into the
 /// original `text`. Each component that is provided, will be non-`null`.