Commit d051b13963

Andrew Kelley <andrew@ziglang.org>
2024-02-23 02:36:40
std.http.Server: implement respondStreaming with unknown len
no content-length header no transfer-encoding header
1 parent 737e7be
Changed files (2)
lib
lib/std/http/Server.zig
@@ -474,7 +474,10 @@ pub const Request = struct {
             }) catch unreachable;
             if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
 
-            if (options.content_length) |len| {
+            if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
+                .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
+                .none => {},
+            } else if (options.content_length) |len| {
                 h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable;
             } else {
                 h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n");
@@ -496,7 +499,12 @@ pub const Request = struct {
             .send_buffer = options.send_buffer,
             .send_buffer_start = 0,
             .send_buffer_end = h.items.len,
-            .content_length = options.content_length,
+            .transfer_encoding = if (o.transfer_encoding) |te| switch (te) {
+                .chunked => .chunked,
+                .none => .none,
+            } else if (options.content_length) |len| .{
+                .content_length = len,
+            } else .chunked,
             .elide_body = elide_body,
             .chunk_len = 0,
         };
@@ -709,12 +717,21 @@ pub const Response = struct {
     send_buffer_end: usize,
     /// `null` means transfer-encoding: chunked.
     /// As a debugging utility, counts down to zero as bytes are written.
-    content_length: ?u64,
+    transfer_encoding: TransferEncoding,
     elide_body: bool,
     /// Indicates how much of the end of the `send_buffer` corresponds to a
     /// chunk. This amount of data will be wrapped by an HTTP chunk header.
     chunk_len: usize,
 
+    pub const TransferEncoding = union(enum) {
+        /// End of connection signals the end of the stream.
+        none,
+        /// As a debugging utility, counts down to zero as bytes are written.
+        content_length: u64,
+        /// Each chunk is wrapped in a header and trailer.
+        chunked,
+    };
+
     pub const WriteError = net.Stream.WriteError;
 
     /// When using content-length, asserts that the amount of data sent matches
@@ -723,11 +740,17 @@ pub const Response = struct {
     /// end-of-stream message, then flushes the stream to the system.
     /// Respects the value of `elide_body` to omit all data after the headers.
     pub fn end(r: *Response) WriteError!void {
-        if (r.content_length) |len| {
-            assert(len == 0); // Trips when end() called before all bytes written.
-            try flush_cl(r);
-        } else {
-            try flush_chunked(r, &.{});
+        switch (r.transfer_encoding) {
+            .content_length => |len| {
+                assert(len == 0); // Trips when end() called before all bytes written.
+                try flush_cl(r);
+            },
+            .none => {
+                try flush_cl(r);
+            },
+            .chunked => {
+                try flush_chunked(r, &.{});
+            },
         }
         r.* = undefined;
     }
@@ -752,16 +775,21 @@ pub const Response = struct {
     /// May return 0, which does not indicate end of stream. The caller decides
     /// when the end of stream occurs by calling `end`.
     pub fn write(r: *Response, bytes: []const u8) WriteError!usize {
-        if (r.content_length != null) {
-            return write_cl(r, bytes);
-        } else {
-            return write_chunked(r, bytes);
+        switch (r.transfer_encoding) {
+            .content_length, .none => return write_cl(r, bytes),
+            .chunked => return write_chunked(r, bytes),
         }
     }
 
     fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize {
         const r: *Response = @constCast(@alignCast(@ptrCast(context)));
-        const len = &r.content_length.?;
+
+        var trash: u64 = std.math.maxInt(u64);
+        const len = switch (r.transfer_encoding) {
+            .content_length => |*len| len,
+            else => &trash,
+        };
+
         if (r.elide_body) {
             len.* -= bytes.len;
             return bytes.len;
@@ -805,7 +833,7 @@ pub const Response = struct {
 
     fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize {
         const r: *Response = @constCast(@alignCast(@ptrCast(context)));
-        assert(r.content_length == null);
+        assert(r.transfer_encoding == .chunked);
 
         if (r.elide_body)
             return bytes.len;
@@ -867,15 +895,13 @@ pub const Response = struct {
     /// This is redundant after calling `end`.
     /// Respects the value of `elide_body` to omit all data after the headers.
     pub fn flush(r: *Response) WriteError!void {
-        if (r.content_length != null) {
-            return flush_cl(r);
-        } else {
-            return flush_chunked(r, null);
+        switch (r.transfer_encoding) {
+            .none, .content_length => return flush_cl(r),
+            .chunked => return flush_chunked(r, null),
         }
     }
 
     fn flush_cl(r: *Response) WriteError!void {
-        assert(r.content_length != null);
         try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]);
         r.send_buffer_start = 0;
         r.send_buffer_end = 0;
@@ -884,7 +910,7 @@ pub const Response = struct {
     fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void {
         const max_trailers = 25;
         if (end_trailers) |trailers| assert(trailers.len <= max_trailers);
-        assert(r.content_length == null);
+        assert(r.transfer_encoding == .chunked);
 
         const http_headers = r.send_buffer[r.send_buffer_start .. r.send_buffer_end - r.chunk_len];
 
@@ -976,7 +1002,10 @@ pub const Response = struct {
 
     pub fn writer(r: *Response) std.io.AnyWriter {
         return .{
-            .writeFn = if (r.content_length != null) write_cl else write_chunked,
+            .writeFn = switch (r.transfer_encoding) {
+                .none, .content_length => write_cl,
+                .chunked => write_chunked,
+            },
             .context = r,
         };
     }
lib/std/http/test.zig
@@ -222,6 +222,72 @@ test "echo content server" {
     }
 }
 
+test "Server.Request.respondStreaming non-chunked, unknown content-length" {
+    // In this case, the response is expected to stream until the connection is
+    // closed, indicating the end of the body.
+    const test_server = try createTestServer(struct {
+        fn run(net_server: *std.net.Server) anyerror!void {
+            var header_buffer: [1000]u8 = undefined;
+            var remaining: usize = 1;
+            while (remaining != 0) : (remaining -= 1) {
+                const conn = try net_server.accept();
+                defer conn.stream.close();
+
+                var server = std.http.Server.init(conn, &header_buffer);
+
+                try expectEqual(.ready, server.state);
+                var request = try server.receiveHead();
+                try expectEqualStrings(request.head.target, "/foo");
+                var send_buffer: [500]u8 = undefined;
+                var response = request.respondStreaming(.{
+                    .send_buffer = &send_buffer,
+                    .respond_options = .{
+                        .transfer_encoding = .none,
+                    },
+                });
+                var total: usize = 0;
+                for (0..500) |i| {
+                    var buf: [30]u8 = undefined;
+                    const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i});
+                    try response.writeAll(line);
+                    total += line.len;
+                }
+                try expectEqual(7390, total);
+                try response.end();
+                try expectEqual(.closing, server.state);
+            }
+        }
+    });
+    defer test_server.destroy();
+
+    const request_bytes = "GET /foo HTTP/1.1\r\n\r\n";
+    const gpa = std.testing.allocator;
+    const stream = try std.net.tcpConnectToHost(gpa, "127.0.0.1", test_server.port());
+    defer stream.close();
+    try stream.writeAll(request_bytes);
+
+    const response = try stream.reader().readAllAlloc(gpa, 8192);
+    defer gpa.free(response);
+
+    var expected_response = std.ArrayList(u8).init(gpa);
+    defer expected_response.deinit();
+
+    try expected_response.appendSlice("HTTP/1.1 200 OK\r\n\r\n");
+
+    {
+        var total: usize = 0;
+        for (0..500) |i| {
+            var buf: [30]u8 = undefined;
+            const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i});
+            try expected_response.appendSlice(line);
+            total += line.len;
+        }
+        try expectEqual(7390, total);
+    }
+
+    try expectEqualStrings(expected_response.items, response);
+}
+
 fn echoTests(client: *std.http.Client, port: u16) !void {
     const gpa = std.testing.allocator;
     var location_buffer: [100]u8 = undefined;