Commit 729a051e9e

Mizuochi Keita <keitam913@yahoo.co.jp>
2023-06-07 18:09:23
std.http: Fix segfault while redirecting
Make to avoid releasing request's connection twice. Change the `Request.connection` field optional. This field is null while the connection is released. Fixes #15965
1 parent e23d48e
Changed files (2)
lib
std
test
standalone
lib/std/http/Client.zig
@@ -451,7 +451,8 @@ pub const Response = struct {
 pub const Request = struct {
     uri: Uri,
     client: *Client,
-    connection: *ConnectionPool.Node,
+    /// is null when this connection is released
+    connection: ?*ConnectionPool.Node,
 
     method: http.Method,
     version: http.Version = .@"HTTP/1.1",
@@ -481,13 +482,14 @@ pub const Request = struct {
             req.response.parser.header_bytes.deinit(req.client.allocator);
         }
 
-        if (!req.response.parser.done) {
-            // If the response wasn't fully read, then we need to close the connection.
-            req.connection.data.closing = true;
+        if (req.connection) |connection| {
+            if (!req.response.parser.done) {
+                // If the response wasn't fully read, then we need to close the connection.
+                connection.data.closing = true;
+            }
+            req.client.connection_pool.release(req.client, connection);
         }
 
-        req.client.connection_pool.release(req.client, req.connection);
-
         req.arena.deinit();
         req.* = undefined;
     }
@@ -504,7 +506,8 @@ pub const Request = struct {
             .zstd => |*zstd| zstd.deinit(),
         }
 
-        req.client.connection_pool.release(req.client, req.connection);
+        req.client.connection_pool.release(req.client, req.connection.?);
+        req.connection = null;
 
         const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
 
@@ -534,7 +537,7 @@ pub const Request = struct {
 
     /// Send the request to the server.
     pub fn start(req: *Request) StartError!void {
-        var buffered = std.io.bufferedWriter(req.connection.data.writer());
+        var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
         const w = buffered.writer();
 
         try w.writeAll(@tagName(req.method));
@@ -544,7 +547,7 @@ pub const Request = struct {
             try w.writeAll(req.uri.host.?);
             try w.writeByte(':');
             try w.print("{}", .{req.uri.port.?});
-        } else if (req.connection.data.proxied) {
+        } else if (req.connection.?.data.proxied) {
             // proxied connections require the full uri
             try w.print("{+/}", .{req.uri});
         } else {
@@ -625,7 +628,7 @@ pub const Request = struct {
 
         var index: usize = 0;
         while (index == 0) {
-            const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
+            const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
             if (amt == 0 and req.response.parser.done) break;
             index += amt;
         }
@@ -643,10 +646,10 @@ pub const Request = struct {
     pub fn wait(req: *Request) WaitError!void {
         while (true) { // handle redirects
             while (true) { // read headers
-                try req.connection.data.fill();
+                try req.connection.?.data.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
-                req.connection.data.drop(@intCast(u16, nchecked));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
+                req.connection.?.data.drop(@intCast(u16, nchecked));
 
                 if (req.response.parser.state.isContent()) break;
             }
@@ -654,12 +657,12 @@ pub const Request = struct {
             try req.response.parse(req.response.parser.header_bytes.items, false);
 
             if (req.response.status == .switching_protocols) {
-                req.connection.data.closing = false;
+                req.connection.?.data.closing = false;
                 req.response.parser.done = true;
             }
 
             if (req.method == .CONNECT and req.response.status == .ok) {
-                req.connection.data.closing = false;
+                req.connection.?.data.closing = false;
                 req.response.parser.done = true;
             }
 
@@ -670,9 +673,9 @@ pub const Request = struct {
             const res_connection = req.response.headers.getFirstValue("connection");
             const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
             if (res_keepalive and (req_keepalive or req_connection == null)) {
-                req.connection.data.closing = false;
+                req.connection.?.data.closing = false;
             } else {
-                req.connection.data.closing = true;
+                req.connection.?.data.closing = true;
             }
 
             if (req.response.transfer_encoding) |te| {
@@ -762,10 +765,10 @@ pub const Request = struct {
             const has_trail = !req.response.parser.state.isContent();
 
             while (!req.response.parser.state.isContent()) { // read trailing headers
-                try req.connection.data.fill();
+                try req.connection.?.data.fill();
 
-                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
-                req.connection.data.drop(@intCast(u16, nchecked));
+                const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
+                req.connection.?.data.drop(@intCast(u16, nchecked));
             }
 
             if (has_trail) {
@@ -803,16 +806,16 @@ pub const Request = struct {
     pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
         switch (req.transfer_encoding) {
             .chunked => {
-                try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
-                try req.connection.data.writeAll(bytes);
-                try req.connection.data.writeAll("\r\n");
+                try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
+                try req.connection.?.data.writeAll(bytes);
+                try req.connection.?.data.writeAll("\r\n");
 
                 return bytes.len;
             },
             .content_length => |*len| {
                 if (len.* < bytes.len) return error.MessageTooLong;
 
-                const amt = try req.connection.data.write(bytes);
+                const amt = try req.connection.?.data.write(bytes);
                 len.* -= amt;
                 return amt;
             },
@@ -832,7 +835,7 @@ pub const Request = struct {
     /// Finish the body of a request. This notifies the server that you have no more data to send.
     pub fn finish(req: *Request) FinishError!void {
         switch (req.transfer_encoding) {
-            .chunked => try req.connection.data.writeAll("0\r\n\r\n"),
+            .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
             .content_length => |len| if (len != 0) return error.MessageNotCompleted,
             .none => {},
         }
test/standalone/http.zig
@@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void {
         try res.writeAll("Hello, ");
         try res.writeAll("Redirected!\n");
         try res.finish();
+    } else if (mem.eql(u8, res.request.target, "/redirect/invalid")) {
+        const invalid_port = try getUnusedTcpPort();
+        const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port});
+        defer salloc.free(location);
+
+        res.status = .found;
+        try res.headers.append("location", location);
+        try res.do();
+        try res.finish();
     } else {
         res.status = .not_found;
         try res.do();
@@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void {
     conn.close();
 }
 
+fn getUnusedTcpPort() !u16 {
+    const addr = try std.net.Address.parseIp("127.0.0.1", 0);
+    var s = std.net.StreamServer.init(.{});
+    defer s.deinit();
+    try s.listen(addr);
+    return s.listen_address.in.getPort();
+}
+
 pub fn main() !void {
     const log = std.log.scoped(.client);
 
@@ -533,6 +550,27 @@ pub fn main() !void {
     // connection has been kept alive
     try testing.expect(client.connection_pool.free_len == 1);
 
+    { // check client without segfault by connection error after redirection
+        var h = http.Headers{ .allocator = calloc };
+        defer h.deinit();
+
+        const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port});
+        defer calloc.free(location);
+        const uri = try std.Uri.parse(location);
+
+        log.info("{s}", .{location});
+        var req = try client.request(.GET, uri, h, .{});
+        defer req.deinit();
+
+        try req.start();
+        const result = req.wait();
+
+        try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error
+    }
+
+    // connection has been kept alive
+    try testing.expect(client.connection_pool.free_len == 1);
+
     client.deinit();
 
     killServer(server.socket.listen_address);