Commit 2c492064fb

Nameless <truemedian@gmail.com>
2023-04-13 06:26:40
std.http: further curate error set, remove last_error
1 parent 038ed32
Changed files (1)
lib
std
lib/std/http/Client.zig
@@ -25,9 +25,6 @@ next_https_rescan_certs: bool = true,
 /// The pool of connections that can be reused (and currently in use).
 connection_pool: ConnectionPool = .{},
 
-/// The last error that occurred on this client. This is not threadsafe, do not expect it to be completely accurate.
-last_error: ?ExtraError = null,
-
 pub const ExtraError = union(enum) {
     pub const TcpConnectError = std.net.TcpConnectToHostError;
     pub const TlsError = std.crypto.tls.Client.InitError(net.Stream);
@@ -184,31 +181,33 @@ pub const Connection = struct {
 
     pub const Protocol = enum { plain, tls };
 
-    pub fn read(conn: *Connection, buffer: []u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.read(buffer),
-            .tls => return conn.tls_client.read(conn.stream, buffer),
-        }
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.read(buffer),
+            .tls => conn.tls_client.read(conn.stream, buffer),
+        } catch |err| switch (err) {
+            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
+            error.TlsAlert => return error.TlsAlert,
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.readAtLeast(buffer, len),
-            .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        }
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.readAtLeast(buffer, len),
+            .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
+        } catch |err| switch (err) {
+            error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
+            error.TlsAlert => return error.TlsAlert,
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub const ReadError = net.Stream.ReadError || error{
-        TlsConnectionTruncated,
-        TlsRecordOverflow,
-        TlsDecodeError,
-        TlsAlert,
-        TlsBadRecordMac,
-        Overflow,
-        TlsBadLength,
-        TlsIllegalParameter,
-        TlsUnexpectedMessage,
-    };
+    pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
 
@@ -217,20 +216,30 @@ pub const Connection = struct {
     }
 
     pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
-        switch (conn.protocol) {
-            .plain => return conn.stream.writeAll(buffer),
-            .tls => return conn.tls_client.writeAll(conn.stream, buffer),
-        }
+        return switch (conn.protocol) {
+            .plain => conn.stream.writeAll(buffer),
+            .tls => conn.tls_client.writeAll(conn.stream, buffer),
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
     pub fn write(conn: *Connection, buffer: []const u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.write(buffer),
-            .tls => return conn.tls_client.write(conn.stream, buffer),
-        }
+        return switch (conn.protocol) {
+            .plain => conn.stream.write(buffer),
+            .tls => conn.tls_client.write(conn.stream, buffer),
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub const WriteError = net.Stream.WriteError || error{};
+    pub const WriteError = error{
+        ConnectionResetByPeer,
+        UnexpectedWriteFailure,
+    };
+
     pub const Writer = std.io.Writer(*Connection, WriteError, write);
 
     pub fn writer(conn: *Connection) Writer {
@@ -604,7 +613,7 @@ pub const Request = struct {
         try buffered.flush();
     }
 
-    pub const TransferReadError = proto.HeadersParser.ReadError || error{ReadFailed};
+    pub const TransferReadError = BufferedConnection.ReadError || proto.HeadersParser.ReadError;
 
     pub const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);
 
@@ -617,10 +626,7 @@ pub const Request = struct {
 
         var index: usize = 0;
         while (index == 0) {
-            const amt = req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip) catch |err| {
-                req.client.last_error = .{ .read = err };
-                return error.ReadFailed;
-            };
+            const amt = try req.response.parser.read(&req.connection.data.buffered, buf[index..], req.response.skip);
             if (amt == 0 and req.response.parser.done) break;
             index += amt;
         }
@@ -638,10 +644,7 @@ pub const Request = struct {
     pub fn do(req: *Request) DoError!void {
         while (true) { // handle redirects
             while (true) { // read headers
-                req.connection.data.buffered.fill() catch |err| {
-                    req.client.last_error = .{ .read = err };
-                    return error.ReadFailed;
-                };
+                try req.connection.data.buffered.fill();
 
                 const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
                 req.connection.data.buffered.clear(@intCast(u16, nchecked));
@@ -712,16 +715,10 @@ pub const Request = struct {
                     if (req.response.headers.transfer_compression) |tc| switch (tc) {
                         .compress => return error.CompressionNotSupported,
                         .deflate => req.response.compression = .{
-                            .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch |err| {
-                                req.client.last_error = .{ .zlib_init = err };
-                                return error.CompressionInitializationFailed;
-                            },
+                            .deflate = std.compress.zlib.zlibStream(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
                         },
                         .gzip => req.response.compression = .{
-                            .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch |err| {
-                                req.client.last_error = .{ .gzip_init = err };
-                                return error.CompressionInitializationFailed;
-                            },
+                            .gzip = std.compress.gzip.decompress(req.client.allocator, req.transferReader()) catch return error.CompressionInitializationFailed,
                         },
                         .zstd => req.response.compression = .{
                             .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()),
@@ -734,7 +731,7 @@ pub const Request = struct {
         }
     }
 
-    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError;
+    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
 
     pub const Reader = std.io.Reader(*Request, ReadError, read);
 
@@ -746,30 +743,15 @@ pub const Request = struct {
     pub fn read(req: *Request, buffer: []u8) ReadError!usize {
         while (true) {
             const out_index = switch (req.response.compression) {
-                .deflate => |*deflate| deflate.read(buffer) catch |err| {
-                    req.client.last_error = .{ .decompress = err };
-                    err catch {};
-                    return error.ReadFailed;
-                },
-                .gzip => |*gzip| gzip.read(buffer) catch |err| {
-                    req.client.last_error = .{ .decompress = err };
-                    err catch {};
-                    return error.ReadFailed;
-                },
-                .zstd => |*zstd| zstd.read(buffer) catch |err| {
-                    req.client.last_error = .{ .decompress = err };
-                    err catch {};
-                    return error.ReadFailed;
-                },
+                .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+                .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+                .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
                 else => try req.transferRead(buffer),
             };
 
             if (out_index == 0) {
                 while (!req.response.parser.state.isContent()) { // read trailing headers
-                    req.connection.data.buffered.fill() catch |err| {
-                        req.client.last_error = .{ .read = err };
-                        return error.ReadFailed;
-                    };
+                    try req.connection.data.buffered.fill();
 
                     const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
                     req.connection.data.buffered.clear(@intCast(u16, nchecked));
@@ -784,17 +766,14 @@ pub const Request = struct {
     pub fn readAll(req: *Request, buffer: []u8) !usize {
         var index: usize = 0;
         while (index < buffer.len) {
-            const amt = read(req, buffer[index..]) catch |err| {
-                req.client.last_error = .{ .read = err };
-                return error.ReadFailed;
-            };
+            const amt = try read(req, buffer[index..]);
             if (amt == 0) break;
             index += amt;
         }
         return index;
     }
 
-    pub const WriteError = error{ WriteFailed, NotWriteable, MessageTooLong };
+    pub const WriteError = BufferedConnection.WriteError || error{ NotWriteable, MessageTooLong };
 
     pub const Writer = std.io.Writer(*Request, WriteError, write);
 
@@ -806,28 +785,16 @@ pub const Request = struct {
     pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
         switch (req.headers.transfer_encoding) {
             .chunked => {
-                req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len}) catch |err| {
-                    req.client.last_error = .{ .write = err };
-                    return error.WriteFailed;
-                };
-                req.connection.data.conn.writeAll(bytes) catch |err| {
-                    req.client.last_error = .{ .write = err };
-                    return error.WriteFailed;
-                };
-                req.connection.data.conn.writeAll("\r\n") catch |err| {
-                    req.client.last_error = .{ .write = err };
-                    return error.WriteFailed;
-                };
+                try req.connection.data.conn.writer().print("{x}\r\n", .{bytes.len});
+                try req.connection.data.conn.writeAll(bytes);
+                try req.connection.data.conn.writeAll("\r\n");
 
                 return bytes.len;
             },
             .content_length => |*len| {
                 if (len.* < bytes.len) return error.MessageTooLong;
 
-                const amt = req.connection.data.conn.write(bytes) catch |err| {
-                    req.client.last_error = .{ .write = err };
-                    return error.WriteFailed;
-                };
+                const amt = try req.connection.data.conn.write(bytes);
                 len.* -= amt;
                 return amt;
             },
@@ -835,8 +802,10 @@ pub const Request = struct {
         }
     }
 
+    pub const FinishError = WriteError || error{ MessageNotCompleted };
+
     /// Finish the body of a request. This notifies the server that you have no more data to send.
-    pub fn finish(req: *Request) !void {
+    pub fn finish(req: *Request) FinishError!void {
         switch (req.headers.transfer_encoding) {
             .chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| {
                 req.client.last_error = .{ .write = err };
@@ -857,7 +826,7 @@ pub fn deinit(client: *Client) void {
     client.* = undefined;
 }
 
-pub const ConnectError = Allocator.Error || error{ ConnectionFailed, TlsInitializationFailed };
+pub const ConnectError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
 
 /// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
 /// This function is threadsafe.
@@ -873,9 +842,16 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     errdefer client.allocator.destroy(conn);
     conn.* = .{ .data = undefined };
 
-    const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| {
-        client.last_error = .{ .connect = err };
-        return error.ConnectionFailed;
+    const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
+        error.ConnectionRefused => return error.ConnectionRefused,
+        error.NetworkUnreachable => return error.NetworkUnreachable,
+        error.ConnectionTimedOut => return error.ConnectionTimedOut,
+        error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+        error.TemporaryNameServerFailure => return error.TemporaryNameServerFailure,
+        error.NameServerFailure => return error.NameServerFailure,
+        error.UnknownHostName => return error.UnknownHostName,
+        error.HostLacksNetworkAddresses => return error.HostLacksNetworkAddresses,
+        else => return error.UnexpectedConnectFailure,
     };
     errdefer stream.close();
 
@@ -896,10 +872,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
             conn.data.buffered.conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
             errdefer client.allocator.destroy(conn.data.buffered.conn.tls_client);
 
-            conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch |err| {
-                client.last_error = .{ .tls = err };
-                return error.TlsInitializationFailed;
-            };
+            conn.data.buffered.conn.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
             // This is appropriate for HTTPS because the HTTP headers contain
             // the content length which is used to detect truncation attacks.
             conn.data.buffered.conn.tls_client.allow_truncation_attacks = true;
@@ -911,12 +884,11 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     return conn;
 }
 
-pub const RequestError = ConnectError || error{
+pub const RequestError = ConnectError || BufferedConnection.WriteError || error{
     UnsupportedUrlScheme,
     UriMissingHost,
 
-    CertificateAuthorityBundleFailed,
-    WriteFailed,
+    CertificateBundleLoadFailure,
 };
 
 pub const Options = struct {
@@ -962,10 +934,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
         defer client.ca_bundle_mutex.unlock();
 
         if (client.next_https_rescan_certs) {
-            client.ca_bundle.rescan(client.allocator) catch |err| {
-                client.last_error = .{ .ca_bundle = err };
-                return error.CertificateAuthorityBundleFailed;
-            };
+            client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure;
             @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
         }
     }
@@ -989,13 +958,7 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Opt
 
     req.arena = std.heap.ArenaAllocator.init(client.allocator);
 
-    req.start(uri, headers) catch |err| {
-        if (err == error.OutOfMemory) return error.OutOfMemory;
-        const err_casted = @errSetCast(BufferedConnection.WriteError, err);
-
-        client.last_error = .{ .write = err_casted };
-        return error.WriteFailed;
-    };
+    try req.start(uri, headers);
 
     return req;
 }