Commit 50e2a5f673

Andrew Kelley <andrew@ziglang.org>
2024-02-12 06:19:41
std.http: remove 'done' flag
This is a state machine that already has a `state` field. No need to additionally store "done" - it just makes things unnecessarily complicated and buggy.
1 parent 06d0c58
Changed files (4)
lib
test
standalone
lib/std/http/Client.zig
@@ -610,7 +610,7 @@ pub const Request = struct {
         req.response.headers.deinit();
 
         if (req.connection) |connection| {
-            if (!req.response.parser.done) {
+            if (req.response.parser.state != .complete) {
                 // If the response wasn't fully read, then we need to close the connection.
                 connection.closing = true;
             }
@@ -624,7 +624,7 @@ pub const Request = struct {
     // This function must deallocate all resources associated with the request, or keep those which will be used
     // This needs to be kept in sync with deinit and request
     fn redirect(req: *Request, uri: Uri) !void {
-        assert(req.response.parser.done);
+        assert(req.response.parser.state == .complete);
 
         switch (req.response.compression) {
             .none => {},
@@ -794,12 +794,12 @@ pub const Request = struct {
     }
 
     fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
-        if (req.response.parser.done) return 0;
+        if (req.response.parser.state == .complete) return 0;
 
         var index: usize = 0;
         while (index == 0) {
             const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
-            if (amt == 0 and req.response.parser.done) break;
+            if (amt == 0 and req.response.parser.state == .complete) break;
             index += amt;
         }
 
@@ -840,7 +840,7 @@ pub const Request = struct {
             try req.response.parse(req.response.parser.get(), false);
 
             if (req.response.status == .@"continue") {
-                req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response
+                req.response.parser.state = .complete; // we're done parsing the continue response, reset to prepare for the real response
                 req.response.parser.reset();
 
                 if (req.handle_continue)
@@ -852,7 +852,7 @@ pub const Request = struct {
             // we're switching protocols, so this connection is no longer doing http
             if (req.method == .CONNECT and req.response.status.class() == .success) {
                 req.connection.?.closing = false;
-                req.response.parser.done = true;
+                req.response.parser.state = .complete;
 
                 return; // the connection is not HTTP past this point, return to the caller
             }
@@ -872,8 +872,10 @@ pub const Request = struct {
             // Any response to a HEAD request and any response with a 1xx (Informational), 204 (No Content), or 304 (Not Modified)
             // status code is always terminated by the first empty line after the header fields, regardless of the header fields
             // present in the message
-            if (req.method == .HEAD or req.response.status.class() == .informational or req.response.status == .no_content or req.response.status == .not_modified) {
-                req.response.parser.done = true;
+            if (req.method == .HEAD or req.response.status.class() == .informational or
+                req.response.status == .no_content or req.response.status == .not_modified)
+            {
+                req.response.parser.state = .complete;
 
                 return; // the response is empty, no further setup or redirection is necessary
             }
@@ -889,7 +891,7 @@ pub const Request = struct {
             } else if (req.response.content_length) |cl| {
                 req.response.parser.next_chunk_length = cl;
 
-                if (cl == 0) req.response.parser.done = true;
+                if (cl == 0) req.response.parser.state = .complete;
             } else {
                 // read until the connection is closed
                 req.response.parser.next_chunk_length = std.math.maxInt(u64);
@@ -947,7 +949,7 @@ pub const Request = struct {
                 try req.send(.{});
             } else {
                 req.response.skip = false;
-                if (!req.response.parser.done) {
+                if (req.response.parser.state != .complete) {
                     switch (req.response.transfer_compression) {
                         .identity => req.response.compression = .none,
                         .compress, .@"x-compress" => return error.CompressionNotSupported,
lib/std/http/protocol.zig
@@ -14,7 +14,7 @@ pub const State = enum {
     seen_r,
     seen_rn,
     seen_rnr,
-    finished,
+    headers_end,
     /// Begin transfer-encoding: chunked parsing states.
     chunk_head_size,
     chunk_head_ext,
@@ -22,46 +22,61 @@ pub const State = enum {
     chunk_data,
     chunk_data_suffix,
     chunk_data_suffix_r,
+    /// When the parser has finished parsing a complete message. A message is
+    /// only complete after the entire body has been read and any trailing
+    /// headers have been parsed.
+    complete,
 
     /// Returns true if the parser is in a content state (i.e. not waiting for more headers).
     pub fn isContent(self: State) bool {
         return switch (self) {
-            .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false,
-            .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true,
+            .invalid,
+            .start,
+            .seen_n,
+            .seen_r,
+            .seen_rn,
+            .seen_rnr,
+            => false,
+
+            .headers_end,
+            .chunk_head_size,
+            .chunk_head_ext,
+            .chunk_head_r,
+            .chunk_data,
+            .chunk_data_suffix,
+            .chunk_data_suffix_r,
+            .complete,
+            => true,
         };
     }
 };
 
 pub const HeadersParser = struct {
-    state: State = .start,
+    state: State,
     /// A fixed buffer of len `max_header_bytes`.
     /// Pointers into this buffer are not stable until after a message is complete.
     header_bytes_buffer: []u8,
     header_bytes_len: u32,
     next_chunk_length: u64,
-    /// Whether this parser is done parsing a complete message.
-    /// A message is only done when the entire payload has been read.
-    done: bool,
 
     /// Initializes the parser with a provided buffer `buf`.
     pub fn init(buf: []u8) HeadersParser {
         return .{
+            .state = .start,
             .header_bytes_buffer = buf,
             .header_bytes_len = 0,
-            .done = false,
             .next_chunk_length = 0,
         };
     }
 
     /// Reinitialize the parser.
-    /// Asserts the parser is in the "done" state.
+    /// Asserts the parser is in the `complete` state.
     pub fn reset(hp: *HeadersParser) void {
-        assert(hp.done);
+        assert(hp.state == .complete);
         hp.* = .{
             .state = .start,
             .header_bytes_buffer = hp.header_bytes_buffer,
             .header_bytes_len = 0,
-            .done = false,
             .next_chunk_length = 0,
         };
     }
@@ -86,7 +101,8 @@ pub const HeadersParser = struct {
         while (true) {
             switch (r.state) {
                 .invalid => unreachable,
-                .finished => return index,
+                .complete => unreachable,
+                .headers_end => return index,
                 .start => switch (len - index) {
                     0 => return index,
                     1 => {
@@ -110,7 +126,7 @@ pub const HeadersParser = struct {
 
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -129,7 +145,7 @@ pub const HeadersParser = struct {
 
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -154,7 +170,7 @@ pub const HeadersParser = struct {
 
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -164,7 +180,7 @@ pub const HeadersParser = struct {
                         }
 
                         switch (b32) {
-                            int32("\r\n\r\n") => r.state = .finished,
+                            int32("\r\n\r\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -212,7 +228,7 @@ pub const HeadersParser = struct {
 
                                 switch (b16) {
                                     int16("\r\n") => r.state = .seen_rn,
-                                    int16("\n\n") => r.state = .finished,
+                                    int16("\n\n") => r.state = .headers_end,
                                     else => {},
                                 }
                             },
@@ -229,7 +245,7 @@ pub const HeadersParser = struct {
 
                                 switch (b16) {
                                     int16("\r\n") => r.state = .seen_rn,
-                                    int16("\n\n") => r.state = .finished,
+                                    int16("\n\n") => r.state = .headers_end,
                                     else => {},
                                 }
 
@@ -246,10 +262,10 @@ pub const HeadersParser = struct {
                                     const b16 = intShift(u16, b32);
 
                                     if (b32 == int32("\r\n\r\n")) {
-                                        r.state = .finished;
+                                        r.state = .headers_end;
                                         return index + i + 4;
                                     } else if (b16 == int16("\n\n")) {
-                                        r.state = .finished;
+                                        r.state = .headers_end;
                                         return index + i + 2;
                                     }
                                 }
@@ -266,7 +282,7 @@ pub const HeadersParser = struct {
 
                                 switch (b16) {
                                     int16("\r\n") => r.state = .seen_rn,
-                                    int16("\n\n") => r.state = .finished,
+                                    int16("\n\n") => r.state = .headers_end,
                                     else => {},
                                 }
 
@@ -286,7 +302,7 @@ pub const HeadersParser = struct {
                     0 => return index,
                     else => {
                         switch (bytes[index]) {
-                            '\n' => r.state = .finished,
+                            '\n' => r.state = .headers_end,
                             else => r.state = .start,
                         }
 
@@ -318,7 +334,7 @@ pub const HeadersParser = struct {
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
                             int16("\n\r") => r.state = .seen_rnr,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -337,12 +353,12 @@ pub const HeadersParser = struct {
 
                         switch (b16) {
                             int16("\r\n") => r.state = .seen_rn,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
                         switch (b24) {
-                            int24("\n\r\n") => r.state = .finished,
+                            int24("\n\r\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -372,8 +388,8 @@ pub const HeadersParser = struct {
                         }
 
                         switch (b16) {
-                            int16("\r\n") => r.state = .finished,
-                            int16("\n\n") => r.state = .finished,
+                            int16("\r\n") => r.state = .headers_end,
+                            int16("\n\n") => r.state = .headers_end,
                             else => {},
                         }
 
@@ -385,7 +401,7 @@ pub const HeadersParser = struct {
                     0 => return index,
                     else => {
                         switch (bytes[index]) {
-                            '\n' => r.state = .finished,
+                            '\n' => r.state = .headers_end,
                             else => r.state = .start,
                         }
 
@@ -486,13 +502,6 @@ pub const HeadersParser = struct {
         return len;
     }
 
-    /// Returns whether or not the parser has finished parsing a complete
-    /// message. A message is only complete after the entire body has been read
-    /// and any trailing headers have been parsed.
-    pub fn isComplete(r: *HeadersParser) bool {
-        return r.done and r.state == .finished;
-    }
-
     pub const CheckCompleteHeadError = error{HttpHeadersOversize};
 
     /// Pushes `in` into the parser. Returns the number of bytes consumed by
@@ -523,13 +532,12 @@ pub const HeadersParser = struct {
     /// See `std.http.Client.Connection` for an example of `conn`.
     pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize {
         assert(r.state.isContent());
-        if (r.done) return 0;
-
         var out_index: usize = 0;
         while (true) {
             switch (r.state) {
+                .complete => return out_index,
                 .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable,
-                .finished => {
+                .headers_end => {
                     const data_avail = r.next_chunk_length;
 
                     if (skip) {
@@ -539,7 +547,8 @@ pub const HeadersParser = struct {
                         conn.drop(@intCast(nread));
                         r.next_chunk_length -= nread;
 
-                        if (r.next_chunk_length == 0 or nread == 0) r.done = true;
+                        if (r.next_chunk_length == 0 or nread == 0)
+                            r.state = .complete;
 
                         return out_index;
                     } else if (out_index < buffer.len) {
@@ -549,7 +558,8 @@ pub const HeadersParser = struct {
                         const nread = try conn.read(buffer[0..can_read]);
                         r.next_chunk_length -= nread;
 
-                        if (r.next_chunk_length == 0 or nread == 0) r.done = true;
+                        if (r.next_chunk_length == 0 or nread == 0)
+                            r.state = .complete;
 
                         return nread;
                     } else {
@@ -566,14 +576,12 @@ pub const HeadersParser = struct {
                         .invalid => return error.HttpChunkInvalid,
                         .chunk_data => if (r.next_chunk_length == 0) {
                             if (std.mem.eql(u8, conn.peek(), "\r\n")) {
-                                r.state = .finished;
-                                r.done = true;
+                                r.state = .complete;
                             } else {
-                                // The trailer section is formatted identically to the header section.
+                                // The trailer section is formatted identically
+                                // to the header section.
                                 r.state = .seen_rn;
                             }
-                            r.done = true;
-
                             return out_index;
                         },
                         else => return out_index,
@@ -611,21 +619,21 @@ pub const HeadersParser = struct {
 };
 
 inline fn int16(array: *const [2]u8) u16 {
-    return @as(u16, @bitCast(array.*));
+    return @bitCast(array.*);
 }
 
 inline fn int24(array: *const [3]u8) u24 {
-    return @as(u24, @bitCast(array.*));
+    return @bitCast(array.*);
 }
 
 inline fn int32(array: *const [4]u8) u32 {
-    return @as(u32, @bitCast(array.*));
+    return @bitCast(array.*);
 }
 
 inline fn intShift(comptime T: type, x: anytype) T {
     switch (@import("builtin").cpu.arch.endian()) {
-        .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))),
-        .big => return @as(T, @truncate(x)),
+        .little => return @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T))),
+        .big => return @truncate(x),
     }
 }
 
lib/std/http/Server.zig
@@ -395,7 +395,7 @@ pub const Response = struct {
             return .reset;
         }
 
-        if (!res.request.parser.done) {
+        if (res.request.parser.state != .complete) {
             // If the request wasn't fully read, then we need to close the connection.
             res.connection.closing = true;
             return .closing;
@@ -534,12 +534,12 @@ pub const Response = struct {
     }
 
     fn transferRead(res: *Response, buf: []u8) TransferReadError!usize {
-        if (res.request.parser.done) return 0;
+        if (res.request.parser.state == .complete) return 0;
 
         var index: usize = 0;
         while (index == 0) {
             const amt = try res.request.parser.read(&res.connection, buf[index..], false);
-            if (amt == 0 and res.request.parser.done) break;
+            if (amt == 0 and res.request.parser.state == .complete) break;
             index += amt;
         }
 
@@ -596,12 +596,12 @@ pub const Response = struct {
         } else if (res.request.content_length) |cl| {
             res.request.parser.next_chunk_length = cl;
 
-            if (cl == 0) res.request.parser.done = true;
+            if (cl == 0) res.request.parser.state = .complete;
         } else {
-            res.request.parser.done = true;
+            res.request.parser.state = .complete;
         }
 
-        if (!res.request.parser.done) {
+        if (res.request.parser.state != .complete) {
             switch (res.request.transfer_compression) {
                 .identity => res.request.compression = .none,
                 .compress, .@"x-compress" => return error.CompressionNotSupported,
test/standalone/http.zig
@@ -165,10 +165,11 @@ fn handleRequest(res: *Server.Response) !void {
 var handle_new_requests = true;
 
 fn runServer(srv: *Server) !void {
+    var client_header_buffer: [1024]u8 = undefined;
     outer: while (handle_new_requests) {
         var res = try srv.accept(.{
             .allocator = salloc,
-            .header_strategy = .{ .dynamic = max_header_size },
+            .client_header_buffer = &client_header_buffer,
         });
         defer res.deinit();
 
@@ -244,7 +245,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -269,7 +273,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -293,7 +300,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.HEAD, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.HEAD, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -319,7 +329,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -344,7 +357,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.HEAD, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.HEAD, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -370,7 +386,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -397,7 +416,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.POST, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         req.transfer_encoding = .{ .content_length = 14 };
@@ -429,7 +451,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -456,7 +481,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.POST, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         req.transfer_encoding = .chunked;
@@ -486,7 +514,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -510,7 +541,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -534,7 +568,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -558,7 +595,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -580,7 +620,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.GET, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.GET, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         try req.send(.{});
@@ -628,7 +671,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.POST, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         req.transfer_encoding = .chunked;
@@ -659,7 +705,10 @@ pub fn main() !void {
         const uri = try std.Uri.parse(location);
 
         log.info("{s}", .{location});
-        var req = try client.open(.POST, uri, h, .{});
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, h, .{
+            .server_header_buffer = &server_header_buffer,
+        });
         defer req.deinit();
 
         req.transfer_encoding = .chunked;
@@ -678,9 +727,17 @@ pub fn main() !void {
         var requests = try calloc.alloc(http.Client.Request, total_connections);
         defer calloc.free(requests);
 
+        var header_bufs = std.ArrayList([]u8).init(calloc);
+        defer header_bufs.deinit();
+        defer for (header_bufs.items) |item| calloc.free(item);
+
         for (0..total_connections) |i| {
-            var req = try client.open(.GET, uri, .{ .allocator = calloc }, .{});
-            req.response.parser.done = true;
+            const headers_buf = try calloc.alloc(u8, 1024);
+            try header_bufs.append(headers_buf);
+            var req = try client.open(.GET, uri, .{ .allocator = calloc }, .{
+                .server_header_buffer = headers_buf,
+            });
+            req.response.parser.state = .complete;
             req.connection.?.closing = false;
             requests[i] = req;
         }