Commit aa090a49d9

Nameless <truemedian@gmail.com>
2023-08-22 17:05:03
std.http: handle expect:100-continue and continue responses
1 parent 5d40338
Changed files (4)
lib
test
standalone
lib/std/http/Client.zig
@@ -478,6 +478,7 @@ pub const Request = struct {
             .zstd => |*zstd| zstd.deinit(),
         }
 
+        req.headers.deinit();
         req.response.headers.deinit();
 
         if (req.response.parser.header_bytes_owned) {
@@ -667,17 +668,19 @@ pub const Request = struct {
 
             try req.response.parse(req.response.parser.header_bytes.items, false);
 
-            if (req.response.status == .switching_protocols) {
-                req.connection.?.data.closing = false;
-                req.response.parser.done = true;
+            if (req.response.status == .@"continue") {
+                req.response.parser.done = true; // we're done parsing the continue response, reset to prepare for the real response
+                req.response.parser.reset();
+                break;
             }
 
-            if (req.method == .CONNECT and req.response.status == .ok) {
+            // we're switching protocols, so this connection is no longer doing http
+            if (req.response.status == .switching_protocols or (req.method == .CONNECT and req.response.status == .ok)) {
                 req.connection.?.data.closing = false;
                 req.response.parser.done = true;
             }
 
-            // we default to using keep-alive if not provided
+            // we default to using keep-alive if not provided in the client if the server asks for it
             const req_connection = req.headers.getFirstValue("connection");
             const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
 
@@ -955,6 +958,38 @@ pub fn connectUnproxied(client: *Client, host: []const u8, port: u16, protocol:
     return conn;
 }
 
+pub const ConnectUnixError = Allocator.Error || std.os.SocketError || error{NameTooLong} || std.os.ConnectError;
+
+pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*ConnectionPool.Node {
+    if (client.connection_pool.findConnection(.{
+        .host = path,
+        .port = 0,
+        .is_tls = false,
+    })) |node|
+        return node;
+
+    const conn = try client.allocator.create(ConnectionPool.Node);
+    errdefer client.allocator.destroy(conn);
+    conn.* = .{ .data = undefined };
+
+    const stream = try std.net.connectUnixSocket(path);
+    errdefer stream.close();
+
+    conn.data = .{
+        .stream = stream,
+        .tls_client = undefined,
+        .protocol = .plain,
+
+        .host = try client.allocator.dupe(u8, path),
+        .port = 0,
+    };
+    errdefer client.allocator.free(conn.data.host);
+
+    client.connection_pool.addUsed(conn);
+
+    return conn;
+}
+
 // Prevents a dependency loop in request()
 const ConnectErrorPartial = ConnectUnproxiedError || error{ UnsupportedUrlScheme, ConnectionRefused };
 pub const ConnectError = ConnectErrorPartial || RequestError;
lib/std/http/protocol.zig
@@ -534,9 +534,9 @@ pub const HeadersParser = struct {
 
                         if (r.next_chunk_length == 0) r.done = true;
 
-                        return 0;
-                    } else {
-                        const out_avail = buffer.len;
+                        return out_index;
+                    } else if (out_index < buffer.len) {
+                        const out_avail = buffer.len - out_index;
 
                         const can_read = @as(usize, @intCast(@min(data_avail, out_avail)));
                         const nread = try conn.read(buffer[0..can_read]);
@@ -545,6 +545,8 @@ pub const HeadersParser = struct {
                         if (r.next_chunk_length == 0) r.done = true;
 
                         return nread;
+                    } else {
+                        return out_index;
                     }
                 },
                 .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => {
@@ -558,6 +560,7 @@ pub const HeadersParser = struct {
                         .chunk_data => if (r.next_chunk_length == 0) {
                             if (std.mem.eql(u8, conn.peek(), "\r\n")) {
                                 r.state = .finished;
+                                r.done = true;
                             } else {
                                 // The trailer section is formatted identically to the header section.
                                 r.state = .seen_rn;
lib/std/http/Server.zig
@@ -411,48 +411,52 @@ pub const Response = struct {
         }
         try w.writeAll("\r\n");
 
-        if (!res.headers.contains("server")) {
-            try w.writeAll("Server: zig (std.http)\r\n");
-        }
+        if (res.status == .@"continue") {
+            res.state = .waited; // we still need to send another request after this
+        } else {
+            if (!res.headers.contains("server")) {
+                try w.writeAll("Server: zig (std.http)\r\n");
+            }
 
-        if (!res.headers.contains("connection")) {
-            const req_connection = res.request.headers.getFirstValue("connection");
-            const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
+            if (!res.headers.contains("connection")) {
+                const req_connection = res.request.headers.getFirstValue("connection");
+                const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
 
-            if (req_keepalive) {
-                try w.writeAll("Connection: keep-alive\r\n");
-            } else {
-                try w.writeAll("Connection: close\r\n");
+                if (req_keepalive) {
+                    try w.writeAll("Connection: keep-alive\r\n");
+                } else {
+                    try w.writeAll("Connection: close\r\n");
+                }
             }
-        }
 
-        const has_transfer_encoding = res.headers.contains("transfer-encoding");
-        const has_content_length = res.headers.contains("content-length");
+            const has_transfer_encoding = res.headers.contains("transfer-encoding");
+            const has_content_length = res.headers.contains("content-length");
 
-        if (!has_transfer_encoding and !has_content_length) {
-            switch (res.transfer_encoding) {
-                .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
-                .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
-                .none => {},
-            }
-        } else {
-            if (has_content_length) {
-                const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
-
-                res.transfer_encoding = .{ .content_length = content_length };
-            } else if (has_transfer_encoding) {
-                const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?;
-                if (std.mem.eql(u8, transfer_encoding, "chunked")) {
-                    res.transfer_encoding = .chunked;
-                } else {
-                    return error.UnsupportedTransferEncoding;
+            if (!has_transfer_encoding and !has_content_length) {
+                switch (res.transfer_encoding) {
+                    .chunked => try w.writeAll("Transfer-Encoding: chunked\r\n"),
+                    .content_length => |content_length| try w.print("Content-Length: {d}\r\n", .{content_length}),
+                    .none => {},
                 }
             } else {
-                res.transfer_encoding = .none;
+                if (has_content_length) {
+                    const content_length = std.fmt.parseInt(u64, res.headers.getFirstValue("content-length").?, 10) catch return error.InvalidContentLength;
+
+                    res.transfer_encoding = .{ .content_length = content_length };
+                } else if (has_transfer_encoding) {
+                    const transfer_encoding = res.headers.getFirstValue("transfer-encoding").?;
+                    if (std.mem.eql(u8, transfer_encoding, "chunked")) {
+                        res.transfer_encoding = .chunked;
+                    } else {
+                        return error.UnsupportedTransferEncoding;
+                    }
+                } else {
+                    res.transfer_encoding = .none;
+                }
             }
-        }
 
-        try w.print("{}", .{res.headers});
+            try w.print("{}", .{res.headers});
+        }
 
         try w.writeAll("\r\n");
 
@@ -516,6 +520,10 @@ pub const Response = struct {
             res.request.parser.done = true;
         }
 
+        if (res.request.method == .HEAD) {
+            res.request.parser.done = true;
+        }
+
         if (!res.request.parser.done) {
             if (res.request.transfer_compression) |tc| switch (tc) {
                 .compress => return error.CompressionNotSupported,
test/standalone/http.zig
@@ -22,6 +22,18 @@ fn handleRequest(res: *Server.Response) !void {
 
     log.info("{s} {s} {s}", .{ @tagName(res.request.method), @tagName(res.request.version), res.request.target });
 
+    if (res.request.headers.contains("expect")) {
+        if (mem.eql(u8, res.request.headers.getFirstValue("expect").?, "100-continue")) {
+            res.status = .@"continue";
+            try res.do();
+            res.status = .ok;
+        } else {
+            res.status = .expectation_failed;
+            try res.do();
+            return;
+        }
+    }
+
     const body = try res.reader().readAllAlloc(salloc, 8192);
     defer salloc.free(body);
 
@@ -62,7 +74,7 @@ fn handleRequest(res: *Server.Response) !void {
         }
 
         try res.finish();
-    } else if (mem.eql(u8, res.request.target, "/echo-content")) {
+    } else if (mem.startsWith(u8, res.request.target, "/echo-content")) {
         try testing.expectEqualStrings("Hello, World!\n", body);
         try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?);
 
@@ -592,6 +604,62 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", res.body.?);
     }
 
+    { // expect: 100-continue
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        try h.append("expect", "100-continue");
+        try h.append("content-type", "text/plain");
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-100", .{port});
+        defer calloc.free(location);
+        const uri = try std.Uri.parse(location);
+
+        log.info("{s}", .{location});
+        var req = try client.request(.POST, uri, h, .{});
+        defer req.deinit();
+
+        req.transfer_encoding = .chunked;
+
+        try req.start();
+        try req.wait();
+        try testing.expectEqual(http.Status.@"continue", req.response.status);
+
+        try req.writeAll("Hello, ");
+        try req.writeAll("World!\n");
+        try req.finish();
+
+        try req.wait();
+        try testing.expectEqual(http.Status.ok, req.response.status);
+
+        const body = try req.reader().readAllAlloc(calloc, 8192);
+        defer calloc.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+    }
+
+    { // expect: garbage
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        try h.append("content-type", "text/plain");
+        try h.append("expect", "garbage");
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/echo-content#expect-garbage", .{port});
+        defer calloc.free(location);
+        const uri = try std.Uri.parse(location);
+
+        log.info("{s}", .{location});
+        var req = try client.request(.POST, uri, h, .{});
+        defer req.deinit();
+
+        req.transfer_encoding = .chunked;
+
+        try req.start();
+        try req.wait();
+        try testing.expectEqual(http.Status.expectation_failed, req.response.status);
+    }
+
     { // issue 16282 *** This test leaves the client in an invalid state, it must be last ***
         const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
         defer calloc.free(location);