Commit c0d8ac83eb

Andrew Kelley <andrew@ziglang.org>
2024-02-22 01:42:17
std.http.Server: fix handling of HEAD + chunked
1 parent 5145134
Changed files (1)
lib
std
lib/std/http/Server.zig
@@ -309,14 +309,13 @@ pub const Request = struct {
 
         var first_buffer: [500]u8 = undefined;
         var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer);
-        h.writerAssumeCapacity().print("{s} {d} {s}\r\n", .{
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
             @tagName(options.version), @intFromEnum(options.status), phrase,
-        }) catch |err| switch (err) {};
+        }) catch unreachable;
         if (keep_alive)
             h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
         if (content.len > 0)
-            h.writerAssumeCapacity().print("content-length: {d}\r\n", .{content.len}) catch |err|
-                switch (err) {};
+            h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable;
 
         var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
         var iovecs_len: usize = 0;
@@ -407,13 +406,13 @@ pub const Request = struct {
 
         var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
 
-        h.writerAssumeCapacity().print("{s} {d} {s}\r\n", .{
+        h.fixedWriter().print("{s} {d} {s}\r\n", .{
             @tagName(o.version), @intFromEnum(o.status), phrase,
-        }) catch |err| switch (err) {};
+        }) catch unreachable;
         if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
 
         if (options.content_length) |len| {
-            h.writerAssumeCapacity().print("content-length: {d}\r\n", .{len}) catch |err| switch (err) {};
+            h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
         } else {
             h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
         }
@@ -443,12 +442,12 @@ pub const Request = struct {
     fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize {
         const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
-        assert(s.state == .receiving_body);
         const remaining_content_length = &request.reader_state.remaining_content_length;
         if (remaining_content_length.* == 0) {
             s.state = .ready;
             return 0;
         }
+        assert(s.state == .receiving_body);
         const available = try fill(s, request.head_end);
         const len = @min(remaining_content_length.*, available.len, buffer.len);
         @memcpy(buffer[0..len], available[0..len]);
@@ -470,8 +469,6 @@ pub const Request = struct {
     fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize {
         const request: *Request = @constCast(@alignCast(@ptrCast(context)));
         const s = request.server;
-        assert(s.state == .receiving_body);
-
         const cp = &request.reader_state.chunk_parser;
         const head_end = request.head_end;
 
@@ -481,6 +478,7 @@ pub const Request = struct {
             switch (cp.state) {
                 .invalid => return 0,
                 .data => {
+                    assert(s.state == .receiving_body);
                     const available = try fill(s, head_end);
                     const len = @min(cp.chunk_len, available.len, buffer.len);
                     @memcpy(buffer[0..len], available[0..len]);
@@ -492,6 +490,7 @@ pub const Request = struct {
                     continue;
                 },
                 else => {
+                    assert(s.state == .receiving_body);
                     const available = try fill(s, head_end);
                     const n = cp.feed(available);
                     switch (cp.state) {
@@ -514,6 +513,8 @@ pub const Request = struct {
                                     const bytes = s.read_buffer[head_end..s.read_buffer_len];
                                     const end = hp.feed(bytes);
                                     if (hp.state == .finished) {
+                                        cp.state = .invalid;
+                                        s.state = .ready;
                                         s.next_request_start = s.read_buffer_len - bytes.len + end;
                                         return out_end;
                                     }
@@ -527,6 +528,8 @@ pub const Request = struct {
                                     const bytes = buf[0..read_n];
                                     const end = hp.feed(bytes);
                                     if (hp.state == .finished) {
+                                        cp.state = .invalid;
+                                        s.state = .ready;
                                         s.next_request_start = s.read_buffer_len - bytes.len + end;
                                         return out_end;
                                     }
@@ -625,14 +628,13 @@ pub const Response = struct {
     /// the value sent in the header, then calls `flush`.
     /// Otherwise, transfer-encoding: chunked is being used, and it writes the
     /// end-of-stream message, then flushes the stream to the system.
-    /// When request method is HEAD, does not write anything to the stream.
+    /// Respects the value of `elide_body` to omit all data after the headers.
     pub fn end(r: *Response) WriteError!void {
         if (r.content_length) |len| {
             assert(len == 0); // Trips when end() called before all bytes written.
-            return flush_cl(r);
-        }
-        if (!r.elide_body) {
-            return flush_chunked(r, &.{});
+            try flush_cl(r);
+        } else {
+            try flush_chunked(r, &.{});
         }
         r.* = undefined;
     }
@@ -644,11 +646,10 @@ pub const Response = struct {
     /// Asserts that the Response is using transfer-encoding: chunked.
     /// Writes the end-of-stream message and any optional trailers, then
     /// flushes the stream to the system.
-    /// When request method is HEAD, does not write anything to the stream.
+    /// Respects the value of `elide_body` to omit all data after the headers.
     /// Asserts there are at most 25 trailers.
     pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void {
         assert(r.content_length == null);
-        if (r.elide_body) return;
         try flush_chunked(r, options.trailers);
         r.* = undefined;
     }
@@ -771,6 +772,7 @@ pub const Response = struct {
 
     /// Sends all buffered data to the client.
     /// This is redundant after calling `end`.
+    /// Respects the value of `elide_body` to omit all data after the headers.
     pub fn flush(r: *Response) WriteError!void {
         if (r.content_length != null) {
             return flush_cl(r);
@@ -790,7 +792,17 @@ pub const Response = struct {
         const max_trailers = 25;
         if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
         assert(r.content_length == null);
-        const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
+
+        const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len];
+
+        if (r.elide_body) {
+            try r.stream.writeAll(http_headers);
+            r.send_buffer_start = 0;
+            r.send_buffer_end = 0;
+            r.chunk_len = 0;
+            return;
+        }
+
         var header_buf: [18]u8 = undefined;
         const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable;
 
@@ -798,8 +810,8 @@ pub const Response = struct {
         var iovecs_len: usize = 0;
 
         iovecs[iovecs_len] = .{
-            .iov_base = r.send_buffer.ptr + r.send_buffer_start,
-            .iov_len = send_buffer_len - r.chunk_len,
+            .iov_base = http_headers.ptr,
+            .iov_len = http_headers.len,
         };
         iovecs_len += 1;