Commit 729a051e9e
Changed files (2)
lib
std
http
test
standalone
lib/std/http/Client.zig
@@ -451,7 +451,8 @@ pub const Response = struct {
pub const Request = struct {
uri: Uri,
client: *Client,
- connection: *ConnectionPool.Node,
+ /// is null when this connection is released
+ connection: ?*ConnectionPool.Node,
method: http.Method,
version: http.Version = .@"HTTP/1.1",
@@ -481,13 +482,14 @@ pub const Request = struct {
req.response.parser.header_bytes.deinit(req.client.allocator);
}
- if (!req.response.parser.done) {
- // If the response wasn't fully read, then we need to close the connection.
- req.connection.data.closing = true;
+ if (req.connection) |connection| {
+ if (!req.response.parser.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ connection.data.closing = true;
+ }
+ req.client.connection_pool.release(req.client, connection);
}
- req.client.connection_pool.release(req.client, req.connection);
-
req.arena.deinit();
req.* = undefined;
}
@@ -504,7 +506,8 @@ pub const Request = struct {
.zstd => |*zstd| zstd.deinit(),
}
- req.client.connection_pool.release(req.client, req.connection);
+ req.client.connection_pool.release(req.client, req.connection.?);
+ req.connection = null;
const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
@@ -534,7 +537,7 @@ pub const Request = struct {
/// Send the request to the server.
pub fn start(req: *Request) StartError!void {
- var buffered = std.io.bufferedWriter(req.connection.data.writer());
+ var buffered = std.io.bufferedWriter(req.connection.?.data.writer());
const w = buffered.writer();
try w.writeAll(@tagName(req.method));
@@ -544,7 +547,7 @@ pub const Request = struct {
try w.writeAll(req.uri.host.?);
try w.writeByte(':');
try w.print("{}", .{req.uri.port.?});
- } else if (req.connection.data.proxied) {
+ } else if (req.connection.?.data.proxied) {
// proxied connections require the full uri
try w.print("{+/}", .{req.uri});
} else {
@@ -625,7 +628,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
- const amt = try req.response.parser.read(&req.connection.data, buf[index..], req.response.skip);
+ const amt = try req.response.parser.read(&req.connection.?.data, buf[index..], req.response.skip);
if (amt == 0 and req.response.parser.done) break;
index += amt;
}
@@ -643,10 +646,10 @@ pub const Request = struct {
pub fn wait(req: *Request) WaitError!void {
while (true) { // handle redirects
while (true) { // read headers
- try req.connection.data.fill();
+ try req.connection.?.data.fill();
- const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
- req.connection.data.drop(@intCast(u16, nchecked));
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
+ req.connection.?.data.drop(@intCast(u16, nchecked));
if (req.response.parser.state.isContent()) break;
}
@@ -654,12 +657,12 @@ pub const Request = struct {
try req.response.parse(req.response.parser.header_bytes.items, false);
if (req.response.status == .switching_protocols) {
- req.connection.data.closing = false;
+ req.connection.?.data.closing = false;
req.response.parser.done = true;
}
if (req.method == .CONNECT and req.response.status == .ok) {
- req.connection.data.closing = false;
+ req.connection.?.data.closing = false;
req.response.parser.done = true;
}
@@ -670,9 +673,9 @@ pub const Request = struct {
const res_connection = req.response.headers.getFirstValue("connection");
const res_keepalive = res_connection != null and !std.ascii.eqlIgnoreCase("close", res_connection.?);
if (res_keepalive and (req_keepalive or req_connection == null)) {
- req.connection.data.closing = false;
+ req.connection.?.data.closing = false;
} else {
- req.connection.data.closing = true;
+ req.connection.?.data.closing = true;
}
if (req.response.transfer_encoding) |te| {
@@ -762,10 +765,10 @@ pub const Request = struct {
const has_trail = !req.response.parser.state.isContent();
while (!req.response.parser.state.isContent()) { // read trailing headers
- try req.connection.data.fill();
+ try req.connection.?.data.fill();
- const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.data.peek());
- req.connection.data.drop(@intCast(u16, nchecked));
+ const nchecked = try req.response.parser.checkCompleteHead(req.client.allocator, req.connection.?.data.peek());
+ req.connection.?.data.drop(@intCast(u16, nchecked));
}
if (has_trail) {
@@ -803,16 +806,16 @@ pub const Request = struct {
pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
switch (req.transfer_encoding) {
.chunked => {
- try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
- try req.connection.data.writeAll(bytes);
- try req.connection.data.writeAll("\r\n");
+ try req.connection.?.data.writer().print("{x}\r\n", .{bytes.len});
+ try req.connection.?.data.writeAll(bytes);
+ try req.connection.?.data.writeAll("\r\n");
return bytes.len;
},
.content_length => |*len| {
if (len.* < bytes.len) return error.MessageTooLong;
- const amt = try req.connection.data.write(bytes);
+ const amt = try req.connection.?.data.write(bytes);
len.* -= amt;
return amt;
},
@@ -832,7 +835,7 @@ pub const Request = struct {
/// Finish the body of a request. This notifies the server that you have no more data to send.
pub fn finish(req: *Request) FinishError!void {
switch (req.transfer_encoding) {
- .chunked => try req.connection.data.writeAll("0\r\n\r\n"),
+ .chunked => try req.connection.?.data.writeAll("0\r\n\r\n"),
.content_length => |len| if (len != 0) return error.MessageNotCompleted,
.none => {},
}
test/standalone/http.zig
@@ -129,6 +129,15 @@ fn handleRequest(res: *Server.Response) !void {
try res.writeAll("Hello, ");
try res.writeAll("Redirected!\n");
try res.finish();
+ } else if (mem.eql(u8, res.request.target, "/redirect/invalid")) {
+ const invalid_port = try getUnusedTcpPort();
+ const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}", .{invalid_port});
+ defer salloc.free(location);
+
+ res.status = .found;
+ try res.headers.append("location", location);
+ try res.do();
+ try res.finish();
} else {
res.status = .not_found;
try res.do();
@@ -180,6 +189,14 @@ fn killServer(addr: std.net.Address) void {
conn.close();
}
+fn getUnusedTcpPort() !u16 {
+ const addr = try std.net.Address.parseIp("127.0.0.1", 0);
+ var s = std.net.StreamServer.init(.{});
+ defer s.deinit();
+ try s.listen(addr);
+ return s.listen_address.in.getPort();
+}
+
pub fn main() !void {
const log = std.log.scoped(.client);
@@ -533,6 +550,27 @@ pub fn main() !void {
// connection has been kept alive
try testing.expect(client.connection_pool.free_len == 1);
+ { // check client without segfault by connection error after redirection
+ var h = http.Headers{ .allocator = calloc };
+ defer h.deinit();
+
+ const location = try std.fmt.allocPrint(calloc, "http://127.0.0.1:{d}/redirect/invalid", .{port});
+ defer calloc.free(location);
+ const uri = try std.Uri.parse(location);
+
+ log.info("{s}", .{location});
+ var req = try client.request(.GET, uri, h, .{});
+ defer req.deinit();
+
+ try req.start();
+ const result = req.wait();
+
+ try testing.expectError(error.ConnectionRefused, result); // expects not segfault but the regular error
+ }
+
+ // connection has been kept alive
+ try testing.expect(client.connection_pool.free_len == 1);
+
client.deinit();
killServer(server.socket.listen_address);