Commit 85221b4e97

Nameless <truemedian@gmail.com>
2023-04-16 23:26:25
std.http: curate some Server errors, fix reading chunked bodies
1 parent 1342942
Changed files (3)
lib/std/http/Client.zig
@@ -193,7 +193,13 @@ pub const Connection = struct {
         };
     }
 
-    pub const ReadError = error{ TlsFailure, TlsAlert, ConnectionTimedOut, ConnectionResetByPeer, UnexpectedReadFailure };
+    pub const ReadError = error{
+        TlsFailure,
+        TlsAlert,
+        ConnectionTimedOut,
+        ConnectionResetByPeer,
+        UnexpectedReadFailure,
+    };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
 
@@ -518,7 +524,10 @@ pub const Request = struct {
         req.* = undefined;
     }
 
-    pub fn start(req: *Request, uri: Uri) !void {
+    pub const StartError = BufferedConnection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };
+
+    /// Send the request to the server.
+    pub fn start(req: *Request, uri: Uri) StartError!void {
         var buffered = std.io.bufferedWriter(req.connection.data.buffered.writer());
         const w = buffered.writer();
 
@@ -575,7 +584,7 @@ pub const Request = struct {
             }
         } else {
             if (has_content_length) {
-                const content_length = try std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10);
+                const content_length = std.fmt.parseInt(u64, req.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
 
                 req.transfer_encoding = .{ .content_length = content_length };
             } else if (has_transfer_encoding) {
@@ -618,7 +627,7 @@ pub const Request = struct {
         return index;
     }
 
-    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed };
+    pub const DoError = RequestError || TransferReadError || proto.HeadersParser.CheckCompleteHeadError || Response.ParseError || Uri.ParseError || error{ TooManyHttpRedirects, HttpRedirectMissingLocation, CompressionInitializationFailed, CompressionNotSupported };
 
     /// Waits for a response from the server and parses any headers that are sent.
     /// This function will block until the final response is received.
@@ -739,25 +748,23 @@ pub const Request = struct {
 
     /// Reads data from the response body. Must be called after `do`.
     pub fn read(req: *Request, buffer: []u8) ReadError!usize {
-        while (true) {
-            const out_index = switch (req.response.compression) {
-                .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
-                .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
-                .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
-                else => try req.transferRead(buffer),
-            };
-
-            if (out_index == 0) {
-                while (!req.response.parser.state.isContent()) { // read trailing headers
-                    try req.connection.data.buffered.fill();
-
-                    const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
-                    req.connection.data.buffered.clear(@intCast(u16, nchecked));
-                }
-            }
+        const out_index = switch (req.response.compression) {
+            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
+            else => try req.transferRead(buffer),
+        };
+
+        if (out_index == 0) {
+            while (!req.response.parser.state.isContent()) { // read trailing headers
+                try req.connection.data.buffered.fill();
 
-            return out_index;
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.buffered.peek());
+                req.connection.data.buffered.clear(@intCast(u16, nchecked));
+            }
         }
+
+        return out_index;
     }
 
     /// Reads data from the response body. Must be called after `do`.
@@ -800,15 +807,19 @@ pub const Request = struct {
         }
     }
 
+    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try write(req, bytes[index..]);
+        }
+    }
+
     pub const FinishError = WriteError || error{MessageNotCompleted};
 
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
         switch (req.transfer_encoding) {
-            .chunked => req.connection.data.conn.writeAll("0\r\n\r\n") catch |err| {
-                req.client.last_error = .{ .write = err };
-                return error.WriteFailed;
-            },
+            .chunked => try req.connection.data.conn.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
@@ -923,7 +934,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
     }
 }
 
-pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
+pub const RequestError = ConnectUnproxiedError || ConnectErrorPartial || Request.StartError || std.fmt.ParseIntError || BufferedConnection.WriteError || error{
     UnsupportedUrlScheme,
     UriMissingHost,
 
@@ -998,6 +1009,7 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
         .handle_redirects = options.handle_redirects,
         .response = .{
             .status = undefined,
+            .reason = undefined,
             .version = undefined,
             .headers = undefined,
             .parser = switch (options.header_strategy) {
@@ -1011,8 +1023,6 @@ pub fn request(client: *Client, uri: Uri, headers: http.Headers, options: Option
 
     req.arena = std.heap.ArenaAllocator.init(client.allocator);
 
-    try req.start(uri);
-
     return req;
 }
 
lib/std/http/Server.zig
@@ -23,21 +23,33 @@ pub const Connection = struct {
 
     pub const Protocol = enum { plain };
 
-    pub fn read(conn: *Connection, buffer: []u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.read(buffer),
+    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.read(buffer),
             // .tls => return conn.tls_client.read(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.readAtLeast(buffer, len),
+    pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.readAtLeast(buffer, len),
             // .tls => return conn.tls_client.readAtLeast(conn.stream, buffer, len),
-        }
+        } catch |err| switch (err) {
+            error.ConnectionTimedOut => return error.ConnectionTimedOut,
+            error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedReadFailure,
+        };
     }
 
-    pub const ReadError = net.Stream.ReadError;
+    pub const ReadError = error{
+        ConnectionTimedOut,
+        ConnectionResetByPeer,
+        UnexpectedReadFailure,
+    };
 
     pub const Reader = std.io.Reader(*Connection, ReadError, read);
 
@@ -45,21 +57,31 @@ pub const Connection = struct {
         return Reader{ .context = conn };
     }
 
-    pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
-        switch (conn.protocol) {
-            .plain => return conn.stream.writeAll(buffer),
+    pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
+        return switch (conn.protocol) {
+            .plain => conn.stream.writeAll(buffer),
             // .tls => return conn.tls_client.writeAll(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub fn write(conn: *Connection, buffer: []const u8) !usize {
-        switch (conn.protocol) {
-            .plain => return conn.stream.write(buffer),
+    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
+        return switch (conn.protocol) {
+            .plain => conn.stream.write(buffer),
             // .tls => return conn.tls_client.write(conn.stream, buffer),
-        }
+        } catch |err| switch (err) {
+            error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
+            else => return error.UnexpectedWriteFailure,
+        };
     }
 
-    pub const WriteError = net.Stream.WriteError || error{};
+    pub const WriteError = error{
+        ConnectionResetByPeer,
+        UnexpectedWriteFailure,
+    };
+
     pub const Writer = std.io.Writer(*Connection, WriteError, write);
 
     pub fn writer(conn: *Connection) Writer {
@@ -155,6 +177,25 @@ pub const BufferedConnection = struct {
     }
 };
 
+/// The mode of transport for responses.
+pub const ResponseTransfer = union(enum) {
+    content_length: u64,
+    chunked: void,
+    none: void,
+};
+
+/// The decompressor for request messages.
+pub const Compression = union(enum) {
+    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
+    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
+    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
+
+    deflate: DeflateDecompressor,
+    gzip: GzipDecompressor,
+    zstd: ZstdDecompressor,
+    none: void,
+};
+
 /// A HTTP request originating from a client.
 pub const Request = struct {
     pub const ParseError = Allocator.Error || error{
@@ -165,10 +206,11 @@ pub const Request = struct {
         HttpHeaderContinuationsUnsupported,
         HttpTransferEncodingUnsupported,
         HttpConnectionHeaderUnsupported,
-        InvalidCharacter,
+        InvalidContentLength,
+        CompressionNotSupported,
     };
 
-    pub fn parse(req: *Request, bytes: []const u8) !void {
+    pub fn parse(req: *Request, bytes: []const u8) ParseError!void {
         var it = mem.tokenize(u8, bytes[0 .. bytes.len - 4], "\r\n");
 
         const first_line = it.next() orelse return error.HttpHeadersInvalid;
@@ -211,7 +253,7 @@ pub const Request = struct {
 
             if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                 if (req.content_length != null) return error.HttpHeadersInvalid;
-                req.content_length = try std.fmt.parseInt(u64, header_value, 10);
+                req.content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;
             } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                 // Transfer-Encoding: second, first
                 // Transfer-Encoding: deflate, chunked
@@ -321,6 +363,8 @@ pub const Response = struct {
         }
     }
 
+    pub const DoError = BufferedConnection.WriteError || error{ UnsupportedTransferEncoding, InvalidContentLength };
+
     /// Send the response headers.
     pub fn do(res: *Response) !void {
         var buffered = std.io.bufferedWriter(res.connection.writer());
@@ -356,7 +400,7 @@ pub const Response = struct {
             }
         } else {
             if (has_content_length) {
-                const content_length = try std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10);
+                const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
 
                 res.transfer_encoding = .{ .content_length = content_length };
             } else if (has_transfer_encoding) {
@@ -386,23 +430,23 @@ pub const Response = struct {
         return .{ .context = res };
     }
 
-    pub fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
-        if (res.request.parser.isComplete()) return 0;
+    fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
+        if (res.request.parser.done) return 0;
 
         var index: usize = 0;
         while (index == 0) {
             const amt = try res.request.parser.read(&res.connection, buf[index..], false);
-            if (amt == 0 and res.request.parser.isComplete()) break;
+            if (amt == 0 and res.request.parser.done) break;
             index += amt;
         }
 
         return index;
     }
 
-    pub const WaitForCompleteHeadError = BufferedConnection.ReadError || proto.HeadersParser.WaitForCompleteHeadError || Request.Headers.ParseError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize } || error{CompressionNotSupported};
+    pub const WaitError = BufferedConnection.ReadError || proto.HeadersParser.CheckCompleteHeadError || Request.ParseError || error{ CompressionInitializationFailed, CompressionNotSupported };
 
     /// Wait for the client to send a complete request head.
-    pub fn wait(res: *Response) !void {
+    pub fn wait(res: *Response) WaitError!void {
         while (true) {
             try res.connection.fill();
 
@@ -445,10 +489,10 @@ pub const Response = struct {
             if (res.request.transfer_compression) |tc| switch (tc) {
                 .compress => return error.CompressionNotSupported,
                 .deflate => res.request.compression = .{
-                    .deflate = try std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()),
+                    .deflate = std.compress.zlib.zlibStream(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
                 },
                 .gzip => res.request.compression = .{
-                    .gzip = try std.compress.gzip.decompress(res.server.allocator, res.transferReader()),
+                    .gzip = std.compress.gzip.decompress(res.server.allocator, res.transferReader()) catch return error.CompressionInitializationFailed,
                 },
                 .zstd => res.request.compression = .{
                     .zstd = std.compress.zstd.decompressStream(res.server.allocator, res.transferReader()),
@@ -457,7 +501,7 @@ pub const Response = struct {
         }
     }
 
-    pub const ReadError = Compression.DeflateDecompressor.Error || Compression.GzipDecompressor.Error || Compression.ZstdDecompressor.Error || WaitForCompleteHeadError;
+    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || error{DecompressionFailure};
 
     pub const Reader = std.io.Reader(*Response, ReadError, read);
 
@@ -466,12 +510,23 @@ pub const Response = struct {
     }
 
     pub fn read(res: *Response, buffer: []u8) ReadError!usize {
-        return switch (res.request.compression) {
-            .deflate => |*deflate| try deflate.read(buffer),
-            .gzip => |*gzip| try gzip.read(buffer),
-            .zstd => |*zstd| try zstd.read(buffer),
+        const out_index = switch (res.request.compression) {
+            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
+            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
+            .zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure,
             else => try res.transferRead(buffer),
         };
+
+        if (out_index == 0) {
+            while (!res.request.parser.state.isContent()) { // read trailing headers
+                try res.connection.fill();
+
+                const nchecked = try res.request.parser.checkCompleteHead(res.server.allocator, res.connection.peek());
+                res.connection.clear(@intCast(u16, nchecked));
+            }
+        }
+
+        return out_index;
     }
 
     pub fn readAll(res: *Response, buffer: []u8) !usize {
@@ -513,9 +568,18 @@ pub const Response = struct {
         }
     }
 
+    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
+        var index: usize = 0;
+        while (index < bytes.len) {
+            index += try write(req, bytes[index..]);
+        }
+    }
+
+    pub const FinishError = WriteError || error{MessageNotCompleted};
+
     /// Finish the body of a request. This notifies the server that you have no more data to send.
-    pub fn finish(res: *Response) !void {
-        switch (res.headers.transfer_encoding) {
+    pub fn finish(res: *Response) FinishError!void {
+        switch (res.transfer_encoding) {
             .chunked => try res.connection.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
@@ -523,25 +587,6 @@ pub const Response = struct {
     }
 };
 
-/// The mode of transport for responses.
-pub const ResponseTransfer = union(enum) {
-    content_length: u64,
-    chunked: void,
-    none: void,
-};
-
-/// The decompressor for request messages.
-pub const Compression = union(enum) {
-    pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Response.TransferReader);
-    pub const GzipDecompressor = std.compress.gzip.Decompress(Response.TransferReader);
-    pub const ZstdDecompressor = std.compress.zstd.DecompressStream(Response.TransferReader, .{});
-
-    deflate: DeflateDecompressor,
-    gzip: GzipDecompressor,
-    zstd: ZstdDecompressor,
-    none: void,
-};
-
 pub fn init(allocator: Allocator, options: net.StreamServer.Options) Server {
     return .{
         .allocator = allocator,
src/Package.zig
@@ -485,6 +485,8 @@ fn fetchAndUnpack(
         var req = try http_client.request(uri, h, .{ .method = .GET });
         defer req.deinit();
 
+        try req.start();
+
         try req.do();
 
         if (mem.endsWith(u8, uri.path, ".tar.gz")) {