Commit 1b3ebfefd8

Nameless <truemedian@gmail.com>
2023-05-03 21:34:10
fix keepalive and large buffered writes
1 parent 5f219a2
Changed files (3)
lib
test
standalone
lib/std/http/Client.zig
@@ -71,7 +71,7 @@ pub const ConnectionPool = struct {
         while (next) |node| : (next = node.prev) {
             if ((node.data.buffered.conn.protocol == .tls) != criteria.is_tls) continue;
             if (node.data.port != criteria.port) continue;
-            if (mem.eql(u8, node.data.host, criteria.host)) continue;
+            if (!mem.eql(u8, node.data.host, criteria.host)) continue;
 
             pool.acquireUnsafe(node);
             return node;
@@ -317,32 +317,29 @@ pub const BufferedConnection = struct {
     }
 
     pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
-        if (bconn.write_buf.len - bconn.write_end <= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..], buffer);
+        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
+            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
             bconn.write_end += @intCast(u16, buffer.len);
         } else {
-            try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-            bconn.write_end = 0;
-
+            try bconn.flush();
             try bconn.conn.writeAll(buffer);
         }
     }
 
     pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
-        if (bconn.write_buf.len - bconn.write_end <= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..], buffer);
+        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
+            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
             bconn.write_end += @intCast(u16, buffer.len);
 
             return buffer.len;
         } else {
-            try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-            bconn.write_end = 0;
-
+            try bconn.flush();
             return try bconn.conn.write(buffer);
         }
     }
 
     pub fn flush(bconn: *BufferedConnection) WriteError!void {
+        defer bconn.write_end = 0;
         return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
     }
 
@@ -720,12 +717,13 @@ pub const Request = struct {
                 req.response.parser.done = true;
             }
 
+            // we default to using keep-alive if not provided
             const req_connection = req.headers.getFirstValue("connection");
             const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
 
             const res_connection = req.response.headers.getFirstValue("connection");
             const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
-            if (req_keepalive and res_keepalive) {
+            if (res_keepalive and (req_keepalive or req_connection == null)) {
                 req.connection.data.closing = false;
             } else {
                 req.connection.data.closing = true;
lib/std/http/Server.zig
@@ -161,32 +161,29 @@ pub const BufferedConnection = struct {
     }
 
     pub fn writeAll(bconn: *BufferedConnection, buffer: []const u8) WriteError!void {
-        if (bconn.write_buf.len - bconn.write_end <= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..], buffer);
+        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
+            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
             bconn.write_end += @intCast(u16, buffer.len);
         } else {
-            try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-            bconn.write_end = 0;
-
+            try bconn.flush();
             try bconn.conn.writeAll(buffer);
         }
     }
 
     pub fn write(bconn: *BufferedConnection, buffer: []const u8) WriteError!usize {
-        if (bconn.write_buf.len - bconn.write_end <= buffer.len) {
-            @memcpy(bconn.write_buf[bconn.write_end..], buffer);
+        if (bconn.write_buf.len - bconn.write_end >= buffer.len) {
+            @memcpy(bconn.write_buf[bconn.write_end..][0..buffer.len], buffer);
             bconn.write_end += @intCast(u16, buffer.len);
 
             return buffer.len;
         } else {
-            try bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
-            bconn.write_end = 0;
-
+            try bconn.flush();
             return try bconn.conn.write(buffer);
         }
     }
 
     pub fn flush(bconn: *BufferedConnection) WriteError!void {
+        defer bconn.write_end = 0;
         return bconn.conn.writeAll(bconn.write_buf[0..bconn.write_end]);
     }
 
@@ -397,12 +394,14 @@ pub const Response = struct {
 
         // A connection is only keep-alive if the Connection header is present and it's value is not "close".
         // The server and client must both agree
+        //
+        // do() defaults to using keep-alive if the client requests it.
         const res_connection = res.headers.getFirstValue("connection");
         const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
 
         const req_connection = res.request.headers.getFirstValue("connection");
         const req_keepalive = req_connection != null and !std.ascii.eqlIgnoreCase("close", req_connection.?);
-        if (res_keepalive and req_keepalive) {
+        if (req_keepalive and (res_keepalive or res_connection == null)) {
             res.connection.conn.closing = false;
         } else {
             res.connection.conn.closing = true;
@@ -424,7 +423,7 @@ pub const Response = struct {
 
         res.headers.clearRetainingCapacity();
 
-        res.request.headers.clearRetainingCapacity();
+        res.request.headers.clearAndFree(); // FIXME: figure out why `clearRetainingCapacity` causes a leak in hash_map here
         res.request.parser.reset();
 
         res.request = Request{
test/standalone/http.zig
@@ -9,8 +9,8 @@ const testing = std.testing;
 
 const max_header_size = 8192;
 
-var gpa_server = std.heap.GeneralPurposeAllocator(.{}){};
-var gpa_client = std.heap.GeneralPurposeAllocator(.{}){};
+var gpa_server = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){};
+var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 }){};
 
 const salloc = gpa_server.allocator();
 const calloc = gpa_client.allocator();
@@ -44,6 +44,24 @@ fn handleRequest(res: *Server.Response) !void {
             try res.writeAll("World!\n");
             try res.finish();
         }
+    } else if (mem.startsWith(u8, res.request.target, "/large")) {
+        res.transfer_encoding = .{ .content_length = 14 * 1024 + 14 * 10 };
+
+        try res.do();
+
+        var i: u32 = 0;
+        while (i < 5) : (i += 1) {
+            try res.writeAll("Hello, World!\n");
+        }
+
+        try res.writeAll("Hello, World!\n" ** 1024);
+
+        i = 0;
+        while (i < 5) : (i += 1) {
+            try res.writeAll("Hello, World!\n");
+        }
+
+        try res.finish();
     } else if (mem.eql(u8, res.request.target, "/echo-content")) {
         try testing.expectEqualStrings("Hello, World!\n", body);
         try testing.expectEqualStrings("text/plain", res.request.headers.getFirstValue("content-type").?);
@@ -68,6 +86,7 @@ fn handleRequest(res: *Server.Response) !void {
         try res.writeAll("World!\n");
         // try res.finish();
         try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
+        try res.connection.flush();
     } else if (mem.eql(u8, res.request.target, "/redirect/1")) {
         res.transfer_encoding = .chunked;
 
@@ -177,8 +196,7 @@ pub fn main() !void {
     const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server});
 
     var client = Client{ .allocator = calloc };
-
-    defer client.deinit();
+    // defer client.deinit(); handled below
 
     { // read content-length response
         var h = http.Headers{ .allocator = calloc };
@@ -202,6 +220,33 @@ pub fn main() !void {
         try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
+    { // read large content-length response
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/large", .{port});
+        defer calloc.free(location);
+        const uri = try std.Uri.parse(location);
+
+        log.info("{s}", .{location});
+        var req = try client.request(.GET, uri, h, .{});
+        defer req.deinit();
+
+        try req.start();
+        try req.wait();
+
+        const body = try req.reader().readAllAlloc(calloc, 8192 * 1024);
+        defer calloc.free(body);
+
+        try testing.expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len);
+    }
+
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // send head request and not read chunked
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -225,6 +270,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("14", req.response.headers.getFirstValue("content-length").?);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // read chunked response
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -247,6 +295,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // send head request and not read chunked
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -270,6 +321,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("chunked", req.response.headers.getFirstValue("transfer-encoding").?);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // check trailing headers
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -292,6 +346,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("aaaa", req.response.headers.getFirstValue("x-checksum").?);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // send content-length request
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -321,6 +378,36 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", body);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
+    { // read content-length response with connection close
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        try h.append("connection", "close");
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/get", .{port});
+        defer calloc.free(location);
+        const uri = try std.Uri.parse(location);
+
+        log.info("{s}", .{location});
+        var req = try client.request(.GET, uri, h, .{});
+        defer req.deinit();
+
+        try req.start();
+        try req.wait();
+
+        const body = try req.reader().readAllAlloc(calloc, 8192);
+        defer calloc.free(body);
+
+        try testing.expectEqualStrings("Hello, World!\n", body);
+        try testing.expectEqualStrings("text/plain", req.response.headers.getFirstValue("content-type").?);
+    }
+
+    // connection has been closed
+    try testing.expect(client.connection_pool.free_len == 0);
+
     { // send chunked request
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -350,6 +437,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", body);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // relative redirect
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -371,6 +461,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", body);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // redirect from root
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -392,6 +485,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", body);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // absolute redirect
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -413,6 +509,9 @@ pub fn main() !void {
         try testing.expectEqualStrings("Hello, World!\n", body);
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     { // too many redirects
         var h = http.Headers{ .allocator = calloc };
         defer h.deinit();
@@ -432,6 +531,11 @@ pub fn main() !void {
         };
     }
 
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
+    client.deinit();
+
     killServer(server.socket.listen_address);
     server_thread.join();
 }