Commit 380916c0f8

Andrew Kelley <andrew@ziglang.org>
2024-02-22 04:51:55
std.http.Server.Request.Respond: support all transfer encodings
Before I mistakenly thought that missing content-length meant zero when it actually means to stream until the connection is closed. Now the respond() function accepts transfer_encoding which can be left as default (use content.len for content-length), set to none which makes it omit the content-length, or chunked, which makes it format the response as a chunked transfer even though the server has the entire contents already buffered. The echo-content tests are moved from test/standalone/http.zig to the standard library where they are actually run.
1 parent 40ed3c4
Changed files (3)
lib
test
standalone
lib/std/http/Server.zig
@@ -279,13 +279,15 @@ pub const Request = struct {
         reason: ?[]const u8 = null,
         keep_alive: bool = true,
         extra_headers: []const http.Header = &.{},
+        transfer_encoding: ?http.TransferEncoding = null,
     };
 
     /// Send an entire HTTP response to the client, including headers and body.
     ///
     /// Automatically handles HEAD requests by omitting the body.
-    /// Uses the "content-length" header unless `content` is empty in which
-    /// case it omits the content-length header.
+    ///
+    /// Unless `transfer_encoding` is specified, uses the "content-length"
+    /// header.
     ///
     /// If the request contains a body and the connection is to be reused,
     /// discards the request body, leaving the Server in the `ready` state. If
@@ -303,7 +305,9 @@ pub const Request = struct {
         assert(options.status != .@"continue");
         assert(options.extra_headers.len <= max_extra_headers);
 
-        const keep_alive = request.discardBody(options.keep_alive);
+        const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none;
+        const server_keep_alive = !transfer_encoding_none and options.keep_alive;
+        const keep_alive = request.discardBody(server_keep_alive);
 
         const phrase = options.reason orelse options.status.phrase() orelse "";
 
@@ -314,9 +318,15 @@ pub const Request = struct {
         }) catch unreachable;
         if (keep_alive)
             h.appendSliceAssumeCapacity("connection: keep-alive\r\n");
-        if (content.len > 0)
+
+        if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) {
+            .none => {},
+            .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"),
+        } else {
             h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable;
+        }
 
+        var chunk_header_buffer: [18]u8 = undefined;
         var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined;
         var iovecs_len: usize = 0;
 
@@ -358,12 +368,47 @@ pub const Request = struct {
         };
         iovecs_len += 1;
 
-        if (request.head.method != .HEAD and content.len > 0) {
-            iovecs[iovecs_len] = .{
-                .iov_base = content.ptr,
-                .iov_len = content.len,
-            };
-            iovecs_len += 1;
+        if (request.head.method != .HEAD) {
+            const is_chunked = (options.transfer_encoding orelse .none) == .chunked;
+            if (is_chunked) {
+                if (content.len > 0) {
+                    const chunk_header = std.fmt.bufPrint(
+                        &chunk_header_buffer,
+                        "{x}\r\n",
+                        .{content.len},
+                    ) catch unreachable;
+
+                    iovecs[iovecs_len] = .{
+                        .iov_base = chunk_header.ptr,
+                        .iov_len = chunk_header.len,
+                    };
+                    iovecs_len += 1;
+
+                    iovecs[iovecs_len] = .{
+                        .iov_base = content.ptr,
+                        .iov_len = content.len,
+                    };
+                    iovecs_len += 1;
+
+                    iovecs[iovecs_len] = .{
+                        .iov_base = "\r\n",
+                        .iov_len = 2,
+                    };
+                    iovecs_len += 1;
+                }
+
+                iovecs[iovecs_len] = .{
+                    .iov_base = "0\r\n\r\n",
+                    .iov_len = 5,
+                };
+                iovecs_len += 1;
+            } else if (content.len > 0) {
+                iovecs[iovecs_len] = .{
+                    .iov_base = content.ptr,
+                    .iov_len = content.len,
+                };
+                iovecs_len += 1;
+            }
         }
 
         try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]);
@@ -400,8 +445,9 @@ pub const Request = struct {
     pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response {
         const o = options.respond_options;
         assert(o.status != .@"continue");
-
-        const keep_alive = request.discardBody(o.keep_alive);
+        const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none;
+        const server_keep_alive = !transfer_encoding_none and o.keep_alive;
+        const keep_alive = request.discardBody(server_keep_alive);
         const phrase = o.reason orelse o.status.phrase() orelse "";
 
         var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer);
@@ -815,26 +861,32 @@ pub const Response = struct {
         };
         iovecs_len += 1;
 
-        iovecs[iovecs_len] = .{
-            .iov_base = chunk_header.ptr,
-            .iov_len = chunk_header.len,
-        };
-        iovecs_len += 1;
+        if (r.chunk_len > 0) {
+            iovecs[iovecs_len] = .{
+                .iov_base = chunk_header.ptr,
+                .iov_len = chunk_header.len,
+            };
+            iovecs_len += 1;
 
-        iovecs[iovecs_len] = .{
-            .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
-            .iov_len = r.chunk_len,
-        };
-        iovecs_len += 1;
+            iovecs[iovecs_len] = .{
+                .iov_base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
+                .iov_len = r.chunk_len,
+            };
+            iovecs_len += 1;
+
+            iovecs[iovecs_len] = .{
+                .iov_base = "\r\n",
+                .iov_len = 2,
+            };
+            iovecs_len += 1;
+        }
 
         if (end_trailers) |trailers| {
-            if (r.chunk_len > 0) {
-                iovecs[iovecs_len] = .{
-                    .iov_base = "\r\n0\r\n",
-                    .iov_len = 5,
-                };
-                iovecs_len += 1;
-            }
+            iovecs[iovecs_len] = .{
+                .iov_base = "0\r\n",
+                .iov_len = 3,
+            };
+            iovecs_len += 1;
 
             for (trailers) |trailer| {
                 iovecs[iovecs_len] = .{
@@ -862,12 +914,6 @@ pub const Response = struct {
                 iovecs_len += 1;
             }
 
-            iovecs[iovecs_len] = .{
-                .iov_base = "\r\n",
-                .iov_len = 2,
-            };
-            iovecs_len += 1;
-        } else if (r.chunk_len > 0) {
             iovecs[iovecs_len] = .{
                 .iov_base = "\r\n",
                 .iov_len = 2,
lib/std/http/test.zig
@@ -1,6 +1,7 @@
 const builtin = @import("builtin");
 const std = @import("std");
 const testing = std.testing;
+const native_endian = builtin.cpu.arch.endian();
 
 test "trailers" {
     if (builtin.single_threaded) return error.SkipZigTest;
@@ -106,7 +107,6 @@ test "HTTP server handles a chunked transfer coding request" {
         return error.SkipZigTest;
     }
 
-    const native_endian = comptime builtin.cpu.arch.endian();
     if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
         // https://github.com/ziglang/zig/issues/13782
         return error.SkipZigTest;
@@ -168,3 +168,243 @@ test "HTTP server handles a chunked transfer coding request" {
 
     server_thread.join();
 }
+
+test "echo content server" {
+    if (builtin.single_threaded) return error.SkipZigTest;
+    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
+        // https://github.com/ziglang/zig/issues/13782
+        return error.SkipZigTest;
+    }
+
+    const gpa = std.testing.allocator;
+
+    const address = try std.net.Address.parseIp("127.0.0.1", 0);
+    var socket_server = try address.listen(.{ .reuse_address = true });
+    defer socket_server.deinit();
+    const port = socket_server.listen_address.in.getPort();
+
+    const server_thread = try std.Thread.spawn(.{}, (struct {
+        fn handleRequest(request: *std.http.Server.Request) !void {
+            std.debug.print("server received {s} {s} {s}\n", .{
+                @tagName(request.head.method),
+                @tagName(request.head.version),
+                request.head.target,
+            });
+
+            const body = try request.reader().readAllAlloc(std.testing.allocator, 8192);
+            defer std.testing.allocator.free(body);
+
+            try testing.expect(std.mem.startsWith(u8, request.head.target, "/echo-content"));
+            try testing.expectEqualStrings("Hello, World!\n", body);
+            try testing.expectEqualStrings("text/plain", request.head.content_type.?);
+
+            var send_buffer: [100]u8 = undefined;
+            var response = request.respondStreaming(.{
+                .send_buffer = &send_buffer,
+                .content_length = switch (request.head.transfer_encoding) {
+                    .chunked => null,
+                    .none => len: {
+                        try testing.expectEqual(14, request.head.content_length.?);
+                        break :len 14;
+                    },
+                },
+            });
+
+            try response.flush(); // Test an early flush to send the HTTP headers before the body.
+            const w = response.writer();
+            try w.writeAll("Hello, ");
+            try w.writeAll("World!\n");
+            try response.end();
+            std.debug.print("  server finished responding\n", .{});
+        }
+
+        fn run(net_server: *std.net.Server) anyerror!void {
+            var read_buffer: [1024]u8 = undefined;
+
+            accept: while (true) {
+                const conn = try net_server.accept();
+                defer conn.stream.close();
+
+                var http_server = std.http.Server.init(conn, &read_buffer);
+
+                while (http_server.state == .ready) {
+                    var request = http_server.receiveHead() catch |err| switch (err) {
+                        error.HttpConnectionClosing => continue :accept,
+                        else => |e| return e,
+                    };
+                    if (std.mem.eql(u8, request.head.target, "/end")) {
+                        return request.respond("", .{ .keep_alive = false });
+                    }
+                    handleRequest(&request) catch |err| {
+                        // This message helps the person troubleshooting determine whether
+                        // output comes from the server thread or the client thread.
+                        std.debug.print("handleRequest failed with '{s}'\n", .{@errorName(err)});
+                        return err;
+                    };
+                }
+            }
+        }
+    }).run, .{&socket_server});
+
+    defer server_thread.join();
+
+    {
+        var client: std.http.Client = .{ .allocator = gpa };
+        defer client.deinit();
+
+        try echoTests(&client, port);
+    }
+}
+
+fn echoTests(client: *std.http.Client, port: u16) !void {
+    const gpa = testing.allocator;
+    var location_buffer: [100]u8 = undefined;
+
+    { // send content-length request
+        const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content", .{port});
+        defer gpa.free(location);
+        const uri = try std.Uri.parse(location);
+
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, .{
+            .server_header_buffer = &server_header_buffer,
+            .extra_headers = &.{
+                .{ .name = "content-type", .value = "text/plain" },
+            },
+        });
+        defer req.deinit();
+
+        req.transfer_encoding = .{ .content_length = 14 };
+
+        try req.send(.{});
+        try req.writeAll("Hello, ");
+        try req.writeAll("World!\n");
+        try req.finish();
+
+        try req.wait();
+
+        const body = try req.reader().readAllAlloc(gpa, 8192);
+        defer gpa.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+    }
+
+    // connection has been kept alive
+    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
+
+    { // send chunked request
+        const uri = try std.Uri.parse(try std.fmt.bufPrint(
+            &location_buffer,
+            "http://127.0.0.1:{d}/echo-content",
+            .{port},
+        ));
+
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, .{
+            .server_header_buffer = &server_header_buffer,
+            .extra_headers = &.{
+                .{ .name = "content-type", .value = "text/plain" },
+            },
+        });
+        defer req.deinit();
+
+        req.transfer_encoding = .chunked;
+
+        try req.send(.{});
+        try req.writeAll("Hello, ");
+        try req.writeAll("World!\n");
+        try req.finish();
+
+        try req.wait();
+
+        const body = try req.reader().readAllAlloc(gpa, 8192);
+        defer gpa.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+    }
+
+    // connection has been kept alive
+    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
+
+    { // Client.fetch()
+
+        const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
+        defer gpa.free(location);
+
+        var body = std.ArrayList(u8).init(gpa);
+        defer body.deinit();
+
+        const res = try client.fetch(.{
+            .location = .{ .url = location },
+            .method = .POST,
+            .payload = "Hello, World!\n",
+            .extra_headers = &.{
+                .{ .name = "content-type", .value = "text/plain" },
+            },
+            .response_storage = .{ .dynamic = &body },
+        });
+        try testing.expectEqual(.ok, res.status);
+        try testing.expectEqualStrings("Hello, World!\n", body.items);
+    }
+
+    { // expect: 100-continue
+        const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#expect-100", .{port});
+        defer gpa.free(location);
+        const uri = try std.Uri.parse(location);
+
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, .{
+            .server_header_buffer = &server_header_buffer,
+            .extra_headers = &.{
+                .{ .name = "expect", .value = "100-continue" },
+                .{ .name = "content-type", .value = "text/plain" },
+            },
+        });
+        defer req.deinit();
+
+        req.transfer_encoding = .chunked;
+
+        try req.send(.{});
+        try req.writeAll("Hello, ");
+        try req.writeAll("World!\n");
+        try req.finish();
+
+        try req.wait();
+        try testing.expectEqual(.ok, req.response.status);
+
+        const body = try req.reader().readAllAlloc(gpa, 8192);
+        defer gpa.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+    }
+
+    { // expect: garbage
+        const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port});
+        defer gpa.free(location);
+        const uri = try std.Uri.parse(location);
+
+        var server_header_buffer: [1024]u8 = undefined;
+        var req = try client.open(.POST, uri, .{
+            .server_header_buffer = &server_header_buffer,
+            .extra_headers = &.{
+                .{ .name = "content-type", .value = "text/plain" },
+                .{ .name = "expect", .value = "garbage" },
+            },
+        });
+        defer req.deinit();
+
+        req.transfer_encoding = .chunked;
+
+        try req.send(.{});
+        try req.wait();
+        try testing.expectEqual(.expectation_failed, req.response.status);
+    }
+
+    _ = try client.fetch(.{
+        .location = .{
+            .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}),
+        },
+    });
+}
test/standalone/http.zig
@@ -81,26 +81,6 @@ fn handleRequest(request: *http.Server.Request, listen_port: u16) !void {
             try w.writeAll("Hello, World!\n");
         }
 
-        try response.end();
-    } else if (mem.startsWith(u8, request.head.target, "/echo-content")) {
-        try testing.expectEqualStrings("Hello, World!\n", body);
-        try testing.expectEqualStrings("text/plain", request.head.content_type.?);
-
-        var response = request.respondStreaming(.{
-            .send_buffer = &send_buffer,
-            .content_length = switch (request.head.transfer_encoding) {
-                .chunked => null,
-                .none => len: {
-                    try testing.expectEqual(14, request.head.content_length.?);
-                    break :len 14;
-                },
-            },
-        });
-
-        try response.flush(); // Test an early flush to send the HTTP headers before the body.
-        const w = response.writer();
-        try w.writeAll("Hello, ");
-        try w.writeAll("World!\n");
         try response.end();
     } else if (mem.eql(u8, request.head.target, "/redirect/1")) {
         var response = request.respondStreaming(.{
@@ -351,39 +331,6 @@ pub fn main() !void {
     // connection has been kept alive
     try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
 
-    { // send content-length request
-        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port});
-        defer calloc.free(location);
-        const uri = try std.Uri.parse(location);
-
-        log.info("{s}", .{location});
-        var server_header_buffer: [1024]u8 = undefined;
-        var req = try client.open(.POST, uri, .{
-            .server_header_buffer = &server_header_buffer,
-            .extra_headers = &.{
-                .{ .name = "content-type", .value = "text/plain" },
-            },
-        });
-        defer req.deinit();
-
-        req.transfer_encoding = .{ .content_length = 14 };
-
-        try req.send(.{});
-        try req.writeAll("Hello, ");
-        try req.writeAll("World!\n");
-        try req.finish();
-
-        try req.wait();
-
-        const body = try req.reader().readAllAlloc(calloc, 8192);
-        defer calloc.free(body);
-
-        try testing.expectEqualStrings("Hello, World!\n", body);
-    }
-
-    // connection has been kept alive
-    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
-
     { // read content-length response with connection close
         const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
         defer calloc.free(location);
@@ -410,39 +357,6 @@ pub fn main() !void {
     // connection has been closed
     try testing.expect(client.connection_pool.free_len == 0);
 
-    { // send chunked request
-        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content", .{port});
-        defer calloc.free(location);
-        const uri = try std.Uri.parse(location);
-
-        log.info("{s}", .{location});
-        var server_header_buffer: [1024]u8 = undefined;
-        var req = try client.open(.POST, uri, .{
-            .server_header_buffer = &server_header_buffer,
-            .extra_headers = &.{
-                .{ .name = "content-type", .value = "text/plain" },
-            },
-        });
-        defer req.deinit();
-
-        req.transfer_encoding = .chunked;
-
-        try req.send(.{});
-        try req.writeAll("Hello, ");
-        try req.writeAll("World!\n");
-        try req.finish();
-
-        try req.wait();
-
-        const body = try req.reader().readAllAlloc(calloc, 8192);
-        defer calloc.free(body);
-
-        try testing.expectEqualStrings("Hello, World!\n", body);
-    }
-
-    // connection has been kept alive
-    try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
-
     { // relative redirect
         const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/1", .{port});
         defer calloc.free(location);
@@ -561,83 +475,6 @@ pub fn main() !void {
     // connection has been kept alive
     try testing.expect(client.http_proxy != null or client.connection_pool.free_len == 1);
 
-    { // Client.fetch()
-
-        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#fetch", .{port});
-        defer calloc.free(location);
-
-        log.info("{s}", .{location});
-        var body = std.ArrayList(u8).init(calloc);
-        defer body.deinit();
-
-        const res = try client.fetch(.{
-            .location = .{ .url = location },
-            .method = .POST,
-            .payload = "Hello, World!\n",
-            .extra_headers = &.{
-                .{ .name = "content-type", .value = "text/plain" },
-            },
-            .response_storage = .{ .dynamic = &body },
-        });
-        try testing.expectEqual(.ok, res.status);
-        try testing.expectEqualStrings("Hello, World!\n", body.items);
-    }
-
-    { // expect: 100-continue
-        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port});
-        defer calloc.free(location);
-        const uri = try std.Uri.parse(location);
-
-        log.info("{s}", .{location});
-        var server_header_buffer: [1024]u8 = undefined;
-        var req = try client.open(.POST, uri, .{
-            .server_header_buffer = &server_header_buffer,
-            .extra_headers = &.{
-                .{ .name = "expect", .value = "100-continue" },
-                .{ .name = "content-type", .value = "text/plain" },
-            },
-        });
-        defer req.deinit();
-
-        req.transfer_encoding = .chunked;
-
-        try req.send(.{});
-        try req.writeAll("Hello, ");
-        try req.writeAll("World!\n");
-        try req.finish();
-
-        try req.wait();
-        try testing.expectEqual(http.Status.ok, req.response.status);
-
-        const body = try req.reader().readAllAlloc(calloc, 8192);
-        defer calloc.free(body);
-
-        try testing.expectEqualStrings("Hello, World!\n", body);
-    }
-
-    { // expect: garbage
-        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port});
-        defer calloc.free(location);
-        const uri = try std.Uri.parse(location);
-
-        log.info("{s}", .{location});
-        var server_header_buffer: [1024]u8 = undefined;
-        var req = try client.open(.POST, uri, .{
-            .server_header_buffer = &server_header_buffer,
-            .extra_headers = &.{
-                .{ .name = "content-type", .value = "text/plain" },
-                .{ .name = "expect", .value = "garbage" },
-            },
-        });
-        defer req.deinit();
-
-        req.transfer_encoding = .chunked;
-
-        try req.send(.{});
-        try req.wait();
-        try testing.expectEqual(http.Status.expectation_failed, req.response.status);
-    }
-
     { // issue 16282 *** This test leaves the client in an invalid state, it must be last ***
         const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
         defer calloc.free(location);