Commit fd2f906d1e
Changed files (2)
lib/std/http/Client.zig
@@ -21,27 +21,51 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
/// it will first rescan the system for root certificates.
next_https_rescan_certs: bool = true,
-connection_pool: std.TailQueue(Connection) = .{},
+connection_mutex: std.Thread.Mutex = .{},
+connection_pool: ConnectionPool = .{},
+connection_used: ConnectionPool = .{},
const ConnectionPool = std.TailQueue(Connection);
const ConnectionNode = ConnectionPool.Node;
+/// Acquires an existing connection from the connection pool. This function is threadsafe.
+pub fn acquire(client: *Client, node: *ConnectionNode) void {
+ client.connection_mutex.lock();
+ defer client.connection_mutex.unlock();
+
+ client.connection_pool.remove(node);
+ client.connection_used.append(node);
+}
+
+/// Tries to release a connection back to the connection pool. This function is threadsafe.
+/// If the connection is marked as closing, it will be closed instead.
pub fn release(client: *Client, node: *ConnectionNode) void {
- if (node.data.unusable) return node.data.close(client);
+ if (node.data.closing) {
+ node.data.close(client);
+
+ return client.allocator.destroy(node);
+ }
+
+ client.connection_mutex.lock();
+ defer client.connection_mutex.unlock();
+ client.connection_used.remove(node);
client.connection_pool.append(node);
}
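
A rough sketch of the new pool lifecycle from a caller's perspective (hypothetical code, not part of this commit; `gpa` stands in for any allocator):

    var client = Client{ .allocator = gpa };
    defer client.deinit();

    // connect() either acquires a matching pooled connection or creates a
    // new one; either way the node ends up on the connection_used list.
    const node = try client.connect("example.com", 443, .tls);

    // release() moves the node back to connection_pool, unless data.closing
    // was set (e.g. by a "Connection: close" response), in which case the
    // connection is closed and the node destroyed.
    client.release(node);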
+const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
+const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
+
pub const Connection = struct {
stream: net.Stream,
/// undefined unless protocol is tls.
- tls_client: std.crypto.tls.Client, // TODO: allocate this, it's currently 16 KB.
+    tls_client: *std.crypto.tls.Client, // allocated separately because it is currently ~16 KB
protocol: Protocol,
host: []u8,
port: u16,
// This connection has been part of a non-keepalive request and cannot be added to the pool.
- unusable: bool = false,
+ closing: bool = false,
pub const Protocol = enum { plain, tls };
@@ -59,6 +83,24 @@ pub const Connection = struct {
}
}
+ pub const ReadError = std.net.Stream.ReadError || error{
+ TlsConnectionTruncated,
+ TlsRecordOverflow,
+ TlsDecodeError,
+ TlsAlert,
+ TlsBadRecordMac,
+ Overflow,
+ TlsBadLength,
+ TlsIllegalParameter,
+ TlsUnexpectedMessage,
+ };
+
+ pub const Reader = std.io.Reader(*Connection, ReadError, read);
+
+ pub fn reader(conn: *Connection) Reader {
+ return Reader{ .context = conn };
+ }
+
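
The adapter above exposes a connection through the standard `std.io.Reader` interface, dispatching to the plain-socket or TLS read path; a minimal usage sketch (assuming `conn: *Connection`):

    var buf: [512]u8 = undefined;
    const n = try conn.reader().read(&buf); // .plain or .tls is picked inside read()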
pub fn writeAll(conn: *Connection, buffer: []const u8) !void {
switch (conn.protocol) {
.plain => return conn.stream.writeAll(buffer),
@@ -73,10 +115,18 @@ pub const Connection = struct {
}
}
+ pub const WriteError = std.net.Stream.WriteError || error{};
+ pub const Writer = std.io.Writer(*Connection, WriteError, write);
+
+ pub fn writer(conn: *Connection) Writer {
+ return Writer{ .context = conn };
+ }
+
pub fn close(conn: *Connection, client: *const Client) void {
if (conn.protocol == .tls) {
// try to cleanly close the TLS connection, for any server that cares.
_ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
+ client.allocator.destroy(conn.tls_client);
}
conn.stream.close();
@@ -85,10 +135,10 @@ pub const Connection = struct {
}
};
-/// TODO: emit error.UnexpectedEndOfStream or something like that when the read
-/// data does not match the content length. This is necessary since HTTPS disables
-/// close_notify protection on underlying TLS streams.
pub const Request = struct {
+ const read_buffer_size = 8192;
+ const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
+
client: *Client,
connection: *ConnectionNode,
redirects_left: u32,
@@ -97,6 +147,11 @@ pub const Request = struct {
/// redirects.
headers: Headers,
+    /// Read buffer for the connection. This is used to pull in large amounts of
+    /// data from the connection even if the user asks for a small amount. It can
+    /// probably be removed with careful planning.
+ read_buffer: [read_buffer_size]u8 = undefined,
+ read_buffer_start: ReadBufferIndex = 0,
+ read_buffer_len: ReadBufferIndex = 0,
+
pub const Response = struct {
headers: Response.Headers,
state: State,
@@ -106,15 +161,24 @@ pub const Request = struct {
header_bytes: std.ArrayListUnmanaged(u8),
max_header_bytes: usize,
next_chunk_length: u64,
- done: bool,
+ done: bool = false,
+
+ compression: union(enum) {
+ deflate: DeflateDecompressor,
+ gzip: GzipDecompressor,
+ none: void,
+ } = .none,
pub const Headers = struct {
status: http.Status,
version: http.Version,
location: ?[]const u8 = null,
content_length: ?u64 = null,
- transfer_encoding: ?http.TransferEncoding = null,
- connection_close: bool = true,
+        transfer_encoding: ?http.TransferEncoding = null, // This should only ever be chunked; compression is handled separately.
+ transfer_compression: ?http.TransferEncoding = null,
+ connection: http.Connection = .close,
+
+ number_of_headers: usize = 0,
pub fn parse(bytes: []const u8) !Response.Headers {
var it = mem.split(u8, bytes[0 .. bytes.len - 4], "\r\n");
@@ -137,6 +201,8 @@ pub const Request = struct {
};
while (it.next()) |line| {
+ headers.number_of_headers += 1;
+
if (line.len == 0) return error.HttpHeadersInvalid;
switch (line[0]) {
' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
@@ -152,14 +218,65 @@ pub const Request = struct {
if (headers.content_length != null) return error.HttpHeadersInvalid;
headers.content_length = try std.fmt.parseInt(u64, header_value, 10);
} else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
- if (headers.transfer_encoding != null) return error.HttpHeadersInvalid;
- headers.transfer_encoding = std.meta.stringToEnum(http.TransferEncoding, header_value) orelse
+ if (headers.transfer_encoding != null or headers.transfer_compression != null) return error.HttpHeadersInvalid;
+
+                // Encodings are listed in the order they were applied, e.g.
+                // "Transfer-Encoding: deflate, chunked", so iterate backwards to
+                // see the outermost (last-applied) encoding first.
+ var iter = std.mem.splitBackwards(u8, header_value, ",");
+
+ if (iter.next()) |first| {
+ const kind = std.meta.stringToEnum(
+ http.TransferEncoding,
+ std.mem.trim(u8, first, " "),
+ ) orelse
+ return error.HttpTransferEncodingUnsupported;
+
+ switch (kind) {
+ .chunked => headers.transfer_encoding = .chunked,
+ .compress => headers.transfer_compression = .compress,
+ .deflate => headers.transfer_compression = .deflate,
+ .gzip => headers.transfer_compression = .gzip,
+ }
+ }
+
+ if (iter.next()) |second| {
+ if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
+
+ const kind = std.meta.stringToEnum(
+ http.TransferEncoding,
+ std.mem.trim(u8, second, " "),
+ ) orelse
+ return error.HttpTransferEncodingUnsupported;
+
+ switch (kind) {
+ .chunked => return error.HttpHeadersInvalid, // chunked must come last
+ .compress => return error.HttpTransferEncodingUnsupported, // compress not supported
+ .deflate => headers.transfer_compression = .deflate,
+ .gzip => headers.transfer_compression = .gzip,
+ }
+ }
+
+ if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
+ } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
+ if (headers.transfer_compression != null) return error.HttpHeadersInvalid;
+
+ const kind = std.meta.stringToEnum(
+ http.TransferEncoding,
+ std.mem.trim(u8, header_value, " "),
+ ) orelse
return error.HttpTransferEncodingUnsupported;
+
+ switch (kind) {
+                    .chunked => return error.HttpHeadersInvalid, // a transfer encoding, not a content encoding
+ .compress => return error.HttpTransferEncodingUnsupported, // compress not supported
+ .deflate => headers.transfer_compression = .deflate,
+ .gzip => headers.transfer_compression = .gzip,
+ }
} else if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
if (std.ascii.eqlIgnoreCase(header_value, "keep-alive")) {
- headers.connection_close = false;
+ headers.connection = .keep_alive;
} else if (std.ascii.eqlIgnoreCase(header_value, "close")) {
- headers.connection_close = true;
+ headers.connection = .close;
} else {
return error.HttpConnectionHeaderUnsupported;
}
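
The net effect of the three branches above, shown on hypothetical header lines:

    // "Transfer-Encoding: gzip, chunked" -> transfer_encoding = .chunked,
    //                                       transfer_compression = .gzip
    // "Transfer-Encoding: chunked"       -> transfer_encoding = .chunked
    // "Transfer-Encoding: gzip"          -> transfer_compression = .gzip
    // "Content-Encoding: deflate"        -> transfer_compression = .deflate
    // "Connection: keep-alive"           -> connection = .keep_alive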
@@ -238,7 +355,6 @@ pub const Request = struct {
.max_header_bytes = max,
.header_bytes_owned = true,
.next_chunk_length = undefined,
- .done = false,
};
}
@@ -250,7 +366,6 @@ pub const Request = struct {
.max_header_bytes = buf.len,
.header_bytes_owned = false,
.next_chunk_length = undefined,
- .done = false,
};
}
@@ -537,10 +652,19 @@ pub const Request = struct {
}
};
+ pub const RequestTransfer = union(enum) {
+ content_length: u64,
+ chunked: void,
+ none: void,
+ };
+
pub const Headers = struct {
version: http.Version = .@"HTTP/1.1",
method: http.Method = .GET,
- connection_close: bool = false,
+ connection: http.Connection = .keep_alive,
+ transfer_encoding: RequestTransfer = .none,
+
+ custom: []const http.CustomHeader = &[_]http.CustomHeader{},
};
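
With the new fields, a caller chooses the body framing and can attach arbitrary headers; a hypothetical POST setup (`client` and `uri` assumed to exist):

    const body = "{\"hello\":\"world\"}";
    var req = try client.request(uri, .{
        .method = .POST,
        .transfer_encoding = .{ .content_length = body.len },
        .custom = &[_]std.http.CustomHeader{
            .{ .name = "Content-Type", .value = "application/json" },
        },
    }, .{});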
pub const Options = struct {
@@ -561,167 +685,131 @@ pub const Request = struct {
};
};
- /// May be skipped if header strategy is buffer.
+ /// Frees all resources associated with the request.
pub fn deinit(req: *Request) void {
+ switch (req.response.compression) {
+ .none => {},
+ .deflate => |*deflate| deflate.deinit(),
+ .gzip => |*gzip| gzip.deinit(),
+ }
+
if (req.response.header_bytes_owned) {
req.response.header_bytes.deinit(req.client.allocator);
}
+
+ if (!req.response.done) {
+ // If the response wasn't fully read, then we need to close the connection.
+ req.connection.data.closing = true;
+ req.client.release(req.connection);
+ }
+
req.* = undefined;
}
- pub const Reader = std.io.Reader(*Request, ReadError, read);
+ const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{
+ UnexpectedEndOfStream,
+ TooManyHttpRedirects,
+ HttpRedirectMissingLocation,
+ HttpHeadersInvalid,
+ };
- pub fn reader(req: *Request) Reader {
- return .{ .context = req };
+ const ReaderRaw = std.io.Reader(*Request, ReadRawError, readRaw);
+
+ /// Read from the underlying stream, without decompressing or parsing the headers. Must be called
+ /// after waitForCompleteHead() has returned successfully.
+ pub fn readRaw(req: *Request, buffer: []u8) ReadRawError!usize {
+ assert(req.response.state.isContent());
+
+ var index: usize = 0;
+ while (index == 0) {
+ const amt = try req.readRawAdvanced(buffer[index..]);
+ const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
+
+ if (amt == 0 and zero_means_end) break;
+ index += amt;
+ }
+
+ return index;
}
- pub fn readAll(req: *Request, buffer: []u8) !usize {
- return readAtLeast(req, buffer, buffer.len);
+ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
+ switch (req.response.state) {
+ .invalid => unreachable,
+ .start, .seen_r, .seen_rn, .seen_rnr => {},
+ else => return 0, // No more headers to read.
+ }
+
+ const i = req.response.findHeadersEnd(buffer[0..]);
+ if (req.response.state == .invalid) return error.HttpHeadersInvalid;
+
+ const headers_data = buffer[0..i];
+ if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
+ return error.HttpHeadersExceededSizeLimit;
+ }
+ try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
+
+ if (req.response.state == .finished) {
+ req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
+
+ if (req.response.headers.connection == .keep_alive) {
+ req.connection.data.closing = false;
+ } else {
+ req.connection.data.closing = true;
+ }
+
+ if (req.response.headers.transfer_encoding) |transfer_encoding| {
+ switch (transfer_encoding) {
+ .chunked => {
+ req.response.next_chunk_length = 0;
+ req.response.state = .chunk_size;
+ },
+ .compress => unreachable,
+ .deflate => unreachable,
+ .gzip => unreachable,
+ }
+ } else if (req.response.headers.content_length) |content_length| {
+ req.response.next_chunk_length = content_length;
+ } else {
+ req.response.done = true;
+ }
+
+ return i;
+ }
+
+ return 0;
}
- pub const ReadError = net.Stream.ReadError || error{
- // From HTTP protocol
- HttpHeadersInvalid,
+    pub const WaitForCompleteHeadError = ReadRawError || error{
+ UnexpectedEndOfStream,
+
HttpHeadersExceededSizeLimit,
- HttpRedirectMissingLocation,
- HttpTransferEncodingUnsupported,
- HttpConnectionHeaderUnsupported,
- HttpContentLengthUnknown,
- TooManyHttpRedirects,
ShortHttpStatusLine,
BadHttpVersion,
HttpHeaderContinuationsUnsupported,
- UnsupportedUrlScheme,
- UriMissingHost,
- UnknownHostName,
-
- // Network problems
- NetworkUnreachable,
- HostLacksNetworkAddresses,
- TemporaryNameServerFailure,
- NameServerFailure,
- ProtocolFamilyNotAvailable,
- ProtocolNotSupported,
-
- // System resource problems
- ProcessFdQuotaExceeded,
- SystemFdQuotaExceeded,
- OutOfMemory,
-
- // TLS problems
- InsufficientEntropy,
- TlsConnectionTruncated,
- TlsRecordOverflow,
- TlsDecodeError,
- TlsAlert,
- TlsBadRecordMac,
- TlsBadLength,
- TlsIllegalParameter,
- TlsUnexpectedMessage,
- TlsDecryptFailure,
- CertificateFieldHasInvalidLength,
- CertificateHostMismatch,
- CertificatePublicKeyInvalid,
- CertificateExpired,
- CertificateFieldHasWrongDataType,
- CertificateIssuerMismatch,
- CertificateNotYetValid,
- CertificateSignatureAlgorithmMismatch,
- CertificateSignatureAlgorithmUnsupported,
- CertificateSignatureInvalid,
- CertificateSignatureInvalidLength,
- CertificateSignatureNamedCurveUnsupported,
- CertificateSignatureUnsupportedBitCount,
- TlsCertificateNotVerified,
- TlsBadSignatureScheme,
- TlsBadRsaSignatureBitCount,
- TlsDecryptError,
- UnsupportedCertificateVersion,
- CertificateTimeInvalid,
- CertificateHasUnrecognizedObjectId,
- CertificateHasInvalidBitString,
- CertificateAuthorityBundleTooBig,
-
- // TODO: convert to higher level errors
- InvalidFormat,
- InvalidPort,
- UnexpectedCharacter,
- Overflow,
- InvalidCharacter,
- AddressFamilyNotSupported,
- AddressInUse,
- AddressNotAvailable,
- ConnectionPending,
- ConnectionRefused,
- FileNotFound,
- PermissionDenied,
- ServiceUnavailable,
- SocketTypeNotSupported,
- FileTooBig,
- LockViolation,
- NoSpaceLeft,
- NotOpenForWriting,
- InvalidEncoding,
- IdentityElement,
- NonCanonical,
- SignatureVerificationFailed,
- MessageTooLong,
- NegativeIntoUnsigned,
- TargetTooSmall,
- BufferTooSmall,
- InvalidSignature,
- NotSquare,
- DiskQuota,
- InvalidEnd,
- Incomplete,
- InvalidIpv4Mapping,
- InvalidIPAddressFormat,
- BadPathName,
- DeviceBusy,
- FileBusy,
- FileLocksNotSupported,
- InvalidHandle,
- InvalidUtf8,
- NameTooLong,
- NoDevice,
- PathAlreadyExists,
- PipeBusy,
- SharingViolation,
- SymLinkLoop,
- FileSystem,
- InterfaceNotFound,
- AlreadyBound,
- FileDescriptorNotASocket,
- NetworkSubsystemFailed,
- NotDir,
- ReadOnlyFileSystem,
- Unseekable,
- MissingEndCertificateMarker,
- InvalidPadding,
- EndOfStream,
- InvalidArgument,
+ HttpTransferEncodingUnsupported,
+ HttpConnectionHeaderUnsupported,
};
- pub fn read(req: *Request, buffer: []u8) ReadError!usize {
- return readAtLeast(req, buffer, 1);
- }
-
- pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
- assert(len <= buffer.len);
- var index: usize = 0;
- while (index < len) {
- const amt = try readAdvanced(req, buffer[index..]);
- const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
+ /// Reads a complete response head. Any leftover data is stored in the request. This function is idempotent.
+ pub fn waitForCompleteHead(req: *Request) WaitForCompleteHeadError!void {
+ if (req.response.state.isContent()) return;
- if (amt == 0 and zero_means_end) break;
- index += amt;
+ while (true) {
+ const nread = try req.connection.data.read(req.read_buffer[0..]);
+ const amt = try checkForCompleteHead(req, req.read_buffer[0..nread]);
+
+ if (amt != 0) {
+ req.read_buffer_start = @intCast(ReadBufferIndex, amt);
+ req.read_buffer_len = @intCast(ReadBufferIndex, nread);
+ return;
+ } else if (nread == 0) {
+ return error.UnexpectedEndOfStream;
+ }
}
- return index;
}
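
After waitForCompleteHead() returns, the parsed head is available before any of the body has been consumed; e.g. (hypothetical caller code):

    try req.waitForCompleteHead();
    const status = req.response.headers.status;
    const len = req.response.headers.content_length; // null when chunked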
/// This one can return 0 without meaning EOF.
- /// TODO change to readvAdvanced
- pub fn readAdvanced(req: *Request, buffer: []u8) !usize {
+ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
if (req.response.done) {
if (req.response.headers.status.class() == .redirect) {
if (req.redirects_left == 0) return error.TooManyHttpRedirects;
@@ -744,82 +832,56 @@ pub const Request = struct {
}
}
- var in = buffer[0..try req.connection.data.read(buffer)];
+ if (req.read_buffer_start == req.read_buffer_len) {
+ const nread = try req.connection.data.read(req.read_buffer[0..]);
+ if (nread == 0) return error.UnexpectedEndOfStream;
+
+ req.read_buffer_start = 0;
+ req.read_buffer_len = @intCast(ReadBufferIndex, nread);
+ }
+
var out_index: usize = 0;
while (true) {
switch (req.response.state) {
- .invalid => unreachable,
- .start, .seen_r, .seen_rn, .seen_rnr => {
- const i = req.response.findHeadersEnd(in);
- if (req.response.state == .invalid) return error.HttpHeadersInvalid;
-
- const headers_data = in[0..i];
- if (req.response.header_bytes.items.len + headers_data.len > req.response.max_header_bytes) {
- return error.HttpHeadersExceededSizeLimit;
- }
- try req.response.header_bytes.appendSlice(req.client.allocator, headers_data);
-
- if (req.response.state == .finished) {
- req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
-
- if (req.response.headers.connection_close == true) {
- req.connection.data.unusable = true;
- } else {
- req.connection.data.unusable = false;
- }
-
- if (req.response.headers.transfer_encoding) |transfer_encoding| {
- switch (transfer_encoding) {
- .chunked => {
- req.response.next_chunk_length = 0;
- req.response.state = .chunk_size;
- },
- .compress => return error.HttpTransferEncodingUnsupported,
- .deflate => return error.HttpTransferEncodingUnsupported,
- .gzip => return error.HttpTransferEncodingUnsupported,
- }
- } else if (req.response.headers.content_length) |content_length| {
- req.response.next_chunk_length = content_length;
- } else {
- return error.HttpContentLengthUnknown;
+ .invalid, .start, .seen_r, .seen_rn, .seen_rnr => unreachable,
+ .finished => {
+ // TODO https://github.com/ziglang/zig/issues/14039
+ const buf_avail = req.read_buffer_len - req.read_buffer_start;
+ const data_avail = req.response.next_chunk_length;
+ const out_avail = buffer.len;
+
+ if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+                        const can_read = @intCast(usize, @min(buf_avail, data_avail));
+                        req.response.next_chunk_length -= can_read;
+                        // Consume the skipped bytes so they are not counted
+                        // against the remaining length again on the next call.
+                        req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
+
+ if (req.response.next_chunk_length == 0) {
+ req.client.release(req.connection);
+ req.connection = undefined;
+ req.response.done = true;
+ continue;
}
- in = in[i..];
- continue;
+ return 0; // skip over as much data as possible
}
- assert(out_index == 0);
- return 0;
- },
- .finished => {
- const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
- req.response.next_chunk_length -= sub_amt;
+ const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
+ req.response.next_chunk_length -= can_read;
+
+ mem.copy(u8, buffer[0..], req.read_buffer[req.read_buffer_start..][0..can_read]);
+ req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
if (req.response.next_chunk_length == 0) {
req.client.release(req.connection);
req.connection = undefined;
-
req.response.done = true;
- assert(in.len == sub_amt); // TODO: figure out how to not read more than necessary.
-
- if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;
-
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- return out_index + sub_amt;
}
- if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) return 0;
-
- if (in.ptr == buffer.ptr) {
- return sub_amt;
- } else {
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- return out_index + sub_amt;
- }
+ return can_read;
},
- .chunk_size_prefix_r => switch (in.len) {
+ .chunk_size_prefix_r => switch (req.read_buffer_len - req.read_buffer_start) {
0 => return out_index,
- 1 => switch (in[0]) {
+ 1 => switch (req.read_buffer[req.read_buffer_start]) {
'\r' => {
req.response.state = .chunk_size_prefix_n;
return out_index;
@@ -829,9 +891,9 @@ pub const Request = struct {
return error.HttpHeadersInvalid;
},
},
- else => switch (int16(in[0..2])) {
+ else => switch (int16(req.read_buffer[req.read_buffer_start..][0..2])) {
int16("\r\n") => {
- in = in[2..];
+ req.read_buffer_start += 2;
req.response.state = .chunk_size;
continue;
},
@@ -841,11 +903,11 @@ pub const Request = struct {
},
},
},
- .chunk_size_prefix_n => switch (in.len) {
+ .chunk_size_prefix_n => switch (req.read_buffer_len - req.read_buffer_start) {
0 => return out_index,
- else => switch (in[0]) {
+ else => switch (req.read_buffer[req.read_buffer_start]) {
'\n' => {
- in = in[1..];
+ req.read_buffer_start += 1;
req.response.state = .chunk_size;
continue;
},
@@ -856,7 +918,7 @@ pub const Request = struct {
},
},
.chunk_size, .chunk_r => {
- const i = req.response.findChunkedLen(in);
+ const i = req.response.findChunkedLen(req.read_buffer[req.read_buffer_start..req.read_buffer_len]);
switch (req.response.state) {
.invalid => return error.HttpHeadersInvalid,
.chunk_data => {
@@ -867,7 +929,8 @@ pub const Request = struct {
return out_index;
}
- in = in[i..];
+
+ req.read_buffer_start += @intCast(ReadBufferIndex, i);
continue;
},
.chunk_size => return out_index,
@@ -876,34 +939,129 @@ pub const Request = struct {
},
.chunk_data => {
// TODO https://github.com/ziglang/zig/issues/14039
- const sub_amt = @intCast(usize, @min(req.response.next_chunk_length, in.len));
- req.response.next_chunk_length -= sub_amt;
+ const buf_avail = req.read_buffer_len - req.read_buffer_start;
+ const data_avail = req.response.next_chunk_length;
+ const out_avail = buffer.len - out_index;
+
+ if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+                        const can_read = @intCast(usize, @min(buf_avail, data_avail));
+                        req.response.next_chunk_length -= can_read;
+                        // Consume the skipped bytes so they are not counted
+                        // against the remaining length again on the next call.
+                        req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
+
+ if (req.response.next_chunk_length == 0) {
+ req.client.release(req.connection);
+ req.connection = undefined;
+ req.response.done = true;
+ continue;
+ }
+
+ return 0; // skip over as much data as possible
+ }
+
+ const can_read = @intCast(usize, @min(@min(buf_avail, data_avail), out_avail));
+ req.response.next_chunk_length -= can_read;
+
+ mem.copy(u8, buffer[out_index..], req.read_buffer[req.read_buffer_start..][0..can_read]);
+ req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
+ out_index += can_read;
if (req.response.next_chunk_length == 0) {
req.response.state = .chunk_size_prefix_r;
- in = in[sub_amt..];
-
- if (req.response.headers.status.class() == .redirect) continue;
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- out_index += sub_amt;
continue;
}
- if (req.response.headers.status.class() == .redirect) return 0;
-
- if (in.ptr == buffer.ptr) {
- return sub_amt;
- } else {
- mem.copy(u8, buffer[out_index..], in[0..sub_amt]);
- out_index += sub_amt;
- return out_index;
- }
+ return out_index;
},
}
}
}
+ pub const ReadError = DeflateDecompressor.Error || GzipDecompressor.Error || WaitForCompleteHeadError || error{
+ BadHeader,
+ InvalidCompression,
+ StreamTooLong,
+ InvalidWindowSize,
+ };
+
+ pub const Reader = std.io.Reader(*Request, ReadError, read);
+
+ pub fn reader(req: *Request) Reader {
+ return .{ .context = req };
+ }
+
+ pub fn read(req: *Request, buffer: []u8) ReadError!usize {
+ if (!req.response.state.isContent()) try req.waitForCompleteHead();
+
+ if (req.response.compression == .none and req.response.state.isContent()) {
+ if (req.response.headers.transfer_compression) |compression| {
+ switch (compression) {
+ .compress => unreachable,
+ .deflate => req.response.compression = .{
+ .deflate = try std.compress.zlib.zlibStream(req.client.allocator, ReaderRaw{ .context = req }),
+ },
+ .gzip => req.response.compression = .{
+ .gzip = try std.compress.gzip.decompress(req.client.allocator, ReaderRaw{ .context = req }),
+ },
+ .chunked => unreachable,
+ }
+ }
+ }
+
+ return switch (req.response.compression) {
+ .deflate => |*deflate| try deflate.read(buffer),
+ .gzip => |*gzip| try gzip.read(buffer),
+ else => try req.readRaw(buffer),
+ };
+ }
+
+ pub fn readAll(req: *Request, buffer: []u8) !usize {
+ var index: usize = 0;
+ while (index < buffer.len) {
+ const amt = try read(req, buffer[index..]);
+ if (amt == 0) break;
+ index += amt;
+ }
+ return index;
+ }
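
Decompression is transparent to callers of read()/readAll(); a sketch (buffer size arbitrary):

    var body_buf: [16 * 1024]u8 = undefined;
    const n = try req.readAll(&body_buf);
    // If the server sent gzip or deflate, the bytes in body_buf[0..n] have
    // already been decompressed by the wrapped decompressor stream.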
+
+    pub const WriteError = Connection.WriteError || error{ MessageTooLong, NotWriteable };
+
+ pub const Writer = std.io.Writer(*Request, WriteError, write);
+
+ pub fn writer(req: *Request) Writer {
+ return .{ .context = req };
+ }
+
+ /// Write `bytes` to the server. The `transfer_encoding` request header determines how data will be sent.
+ pub fn write(req: *Request, bytes: []const u8) !usize {
+ switch (req.headers.transfer_encoding) {
+ .chunked => {
+ try req.connection.data.writer().print("{x}\r\n", .{bytes.len});
+ try req.connection.data.writeAll(bytes);
+ try req.connection.data.writeAll("\r\n");
+
+ return bytes.len;
+ },
+ .content_length => |*len| {
+ if (len.* < bytes.len) return error.MessageTooLong;
+
+ const amt = try req.connection.data.write(bytes);
+ len.* -= amt;
+ return amt;
+ },
+ .none => return error.NotWriteable,
+ }
+ }
+
+ /// Finish the body of a request. This notifies the server that you have no more data to send.
+ pub fn finish(req: *Request) !void {
+ switch (req.headers.transfer_encoding) {
+            .chunked => try req.connection.data.writeAll("0\r\n\r\n"), // last-chunk plus terminating CRLF
+ .content_length => |len| if (len != 0) return error.MessageNotCompleted,
+ .none => {},
+ }
+ }
+
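
A hypothetical chunked upload using the new writer()/finish() pair:

    var req = try client.request(uri, .{
        .method = .POST,
        .transfer_encoding = .chunked,
    }, .{});
    defer req.deinit();

    try req.writer().writeAll("part one");
    try req.writer().writeAll("part two");
    try req.finish(); // terminating zero-length chunk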
inline fn int16(array: *const [2]u8) u16 {
return @bitCast(u16, array.*);
}
@@ -917,6 +1075,10 @@ pub const Request = struct {
}
test {
+ const builtin = @import("builtin");
+
+ if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
_ = Response;
}
};
@@ -931,23 +1093,39 @@ pub fn deinit(client: *Client) void {
client.allocator.destroy(node);
}
+ next = client.connection_used.first;
+ while (next) |node| {
+ next = node.next;
+
+ node.data.close(client);
+
+ client.allocator.destroy(node);
+ }
+
client.ca_bundle.deinit(client.allocator);
client.* = undefined;
}
-pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) !*ConnectionNode {
- var potential = client.connection_pool.last;
- while (potential) |node| {
- const same_host = mem.eql(u8, node.data.host, host);
- const same_port = node.data.port == port;
- const same_protocol = node.data.protocol == protocol;
+pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream);
- if (same_host and same_port and same_protocol) {
- client.connection_pool.remove(node);
- return node;
- }
+pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
+ { // Search through the connection pool for a potential connection.
+ client.connection_mutex.lock();
+ defer client.connection_mutex.unlock();
- potential = node.prev;
+ var potential = client.connection_pool.last;
+ while (potential) |node| {
+ const same_host = mem.eql(u8, node.data.host, host);
+ const same_port = node.data.port == port;
+ const same_protocol = node.data.protocol == protocol;
+
+ if (same_host and same_port and same_protocol) {
+ client.acquire(node);
+ return node;
+ }
+
+ potential = node.prev;
+ }
}
const conn = try client.allocator.create(ConnectionNode);
@@ -964,17 +1142,35 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
switch (protocol) {
.plain => {},
.tls => {
- conn.data.tls_client = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
+ conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client);
+ conn.data.tls_client.* = try std.crypto.tls.Client.init(conn.data.stream, client.ca_bundle, host);
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
conn.data.tls_client.allow_truncation_attacks = true;
},
}
+ {
+ client.connection_mutex.lock();
+ defer client.connection_mutex.unlock();
+
+ client.connection_used.append(conn);
+ }
+
return conn;
}
-pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) !Request {
+pub const RequestError = ConnectError || Connection.WriteError || error{
+ UnsupportedUrlScheme,
+ UriMissingHost,
+
+ CertificateAuthorityBundleTooBig,
+ InvalidPadding,
+ MissingEndCertificateMarker,
+ Unseekable,
+};
+
+pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
const protocol: Connection.Protocol = if (mem.eql(u8, uri.scheme, "http"))
.plain
else if (mem.eql(u8, uri.scheme, "https"))
@@ -990,8 +1186,13 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
const host = uri.host orelse return error.UriMissingHost;
if (client.next_https_rescan_certs and protocol == .tls) {
- try client.ca_bundle.rescan(client.allocator);
- client.next_https_rescan_certs = false;
+ client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
+ defer client.connection_mutex.unlock();
+
+ if (client.next_https_rescan_certs) {
+ try client.ca_bundle.rescan(client.allocator);
+ client.next_https_rescan_certs = false;
+ }
}
var req: Request = .{
@@ -1006,23 +1207,39 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
};
{
- var h = try std.BoundedArray(u8, 1000).init(0);
- try h.appendSlice(@tagName(headers.method));
- try h.appendSlice(" ");
- try h.appendSlice(uri.path);
- try h.appendSlice(" ");
- try h.appendSlice(@tagName(headers.version));
- try h.appendSlice("\r\nHost: ");
- try h.appendSlice(host);
- if (headers.connection_close) {
- try h.appendSlice("\r\nConnection: close");
+ var buffered = std.io.bufferedWriter(req.connection.data.writer());
+ const writer = buffered.writer();
+
+ try writer.writeAll(@tagName(headers.method));
+ try writer.writeByte(' ');
+ try writer.writeAll(uri.path);
+ try writer.writeByte(' ');
+ try writer.writeAll(@tagName(headers.version));
+ try writer.writeAll("\r\nHost: ");
+ try writer.writeAll(host);
+ if (headers.connection == .close) {
+ try writer.writeAll("\r\nConnection: close");
} else {
- try h.appendSlice("\r\nConnection: keep-alive");
+ try writer.writeAll("\r\nConnection: keep-alive");
}
- try h.appendSlice("\r\n\r\n");
+ try writer.writeAll("\r\nAccept-Encoding: gzip, deflate");
- const header_bytes = h.slice();
- try req.connection.data.writeAll(header_bytes);
+ switch (headers.transfer_encoding) {
+ .chunked => try writer.writeAll("\r\nTransfer-Encoding: chunked"),
+ .content_length => |content_length| try writer.print("\r\nContent-Length: {d}", .{content_length}),
+ .none => {},
+ }
+
+ for (headers.custom) |header| {
+ try writer.writeAll("\r\n");
+ try writer.writeAll(header.name);
+ try writer.writeAll(": ");
+ try writer.writeAll(header.value);
+ }
+
+ try writer.writeAll("\r\n\r\n");
+
+ try buffered.flush();
}
return req;
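
For reference, the writer calls above serialize to roughly these bytes for a default GET (hypothetical host and path, CRLF line endings, final blank line included):

    GET /index.html HTTP/1.1
    Host: example.com
    Connection: keep-alive
    Accept-Encoding: gzip, deflate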
@@ -1036,5 +1253,7 @@ test {
return error.SkipZigTest;
}
+ if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
_ = Request;
}
lib/std/http.zig
@@ -253,6 +253,16 @@ pub const TransferEncoding = enum {
gzip,
};
+pub const Connection = enum {
+ keep_alive,
+ close,
+};
+
+pub const CustomHeader = struct {
+ name: []const u8,
+ value: []const u8,
+};
+
const std = @import("std.zig");
test {