Commit 0a4130f364
Changed files (4)
lib
std
lib/std/crypto/tls/Client.zig
@@ -89,7 +89,7 @@ pub const StreamInterface = struct {
};
pub fn InitError(comptime Stream: type) type {
- return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error {
+ return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || error{
InsufficientEntropy,
DiskQuota,
LockViolation,
lib/std/http/Client.zig
@@ -29,9 +29,10 @@ const ConnectionPool = std.TailQueue(Connection);
const ConnectionNode = ConnectionPool.Node;
/// Acquires an existing connection from the connection pool. This function is threadsafe.
-pub fn acquire(client: *Client, node: *ConnectionNode) void {
- client.connection_mutex.lock();
- defer client.connection_mutex.unlock();
+/// If the caller already holds the connection mutex, it should pass `true` for `held`.
+pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
+ if (!held) client.connection_mutex.lock();
+ defer if (!held) client.connection_mutex.unlock();
client.connection_pool.remove(node);
client.connection_used.append(node);
@@ -40,16 +41,17 @@ pub fn acquire(client: *Client, node: *ConnectionNode) void {
/// Tries to release a connection back to the connection pool. This function is threadsafe.
/// If the connection is marked as closing, it will be closed and its node destroyed instead.
pub fn release(client: *Client, node: *ConnectionNode) void {
+ client.connection_mutex.lock(); // held for the whole call, including the closing path below
+ defer client.connection_mutex.unlock();
+
+ client.connection_used.remove(node); // unlink first so a closing node is off the used list before it is destroyed
+
if (node.data.closing) {
node.data.close(client);
return client.allocator.destroy(node); // a closing connection never goes back into the pool
}
- client.connection_mutex.lock();
- defer client.connection_mutex.unlock();
-
- client.connection_used.remove(node);
client.connection_pool.append(node);
}
@@ -83,7 +85,7 @@ pub const Connection = struct {
}
}
- pub const ReadError = std.net.Stream.ReadError || error{
+ pub const ReadError = net.Stream.ReadError || error{
TlsConnectionTruncated,
TlsRecordOverflow,
TlsDecodeError,
@@ -115,7 +117,7 @@ pub const Connection = struct {
}
}
- pub const WriteError = std.net.Stream.WriteError || error{};
+ pub const WriteError = net.Stream.WriteError || error{};
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
@@ -139,14 +141,21 @@ pub const Request = struct {
const read_buffer_size = 8192;
const ReadBufferIndex = std.math.IntFittingRange(0, read_buffer_size);
+ uri: Uri,
client: *Client,
connection: *ConnectionNode,
- redirects_left: u32,
response: Response,
/// These are stored in Request so that they are available when following
/// redirects.
headers: Headers,
+ redirects_left: u32,
+ handle_redirects: bool,
+ compression_init: bool,
+
+ /// Used as an allocator for resolving redirect locations.
+ arena: std.heap.ArenaAllocator,
+
/// Read buffer for the connection. This is used to pull in large amounts of data from the connection even if the user asks for a small amount. This can probably be removed with careful planning.
read_buffer: [read_buffer_size]u8 = undefined,
read_buffer_start: ReadBufferIndex = 0,
@@ -661,6 +670,7 @@ pub const Request = struct {
pub const Headers = struct {
version: http.Version = .@"HTTP/1.1",
method: http.Method = .GET,
+ user_agent: []const u8 = "Zig (std.http)",
connection: http.Connection = .keep_alive,
transfer_encoding: RequestTransfer = .none,
@@ -668,6 +678,7 @@ pub const Request = struct {
};
pub const Options = struct {
+ handle_redirects: bool = true,
max_redirects: u32 = 3,
header_strategy: HeaderStrategy = .{ .dynamic = 16 * 1024 },
@@ -703,10 +714,11 @@ pub const Request = struct {
req.client.release(req.connection);
}
+ req.arena.deinit();
req.* = undefined;
}
- const ReadRawError = Connection.ReadError || std.Uri.ParseError || RequestError || error{
+ const ReadRawError = Connection.ReadError || Uri.ParseError || RequestError || error{
UnexpectedEndOfStream,
TooManyHttpRedirects,
HttpRedirectMissingLocation,
@@ -723,9 +735,7 @@ pub const Request = struct {
var index: usize = 0;
while (index == 0) {
const amt = try req.readRawAdvanced(buffer[index..]);
- const zero_means_end = req.response.done and req.response.headers.status.class() != .redirect;
-
- if (amt == 0 and zero_means_end) break;
+ if (amt == 0 and req.response.done) break;
index += amt;
}
@@ -769,6 +779,8 @@ pub const Request = struct {
}
} else if (req.response.headers.content_length) |content_length| {
req.response.next_chunk_length = content_length;
+
+ if (content_length == 0) req.response.done = true;
} else {
req.response.done = true;
}
@@ -779,7 +791,7 @@ pub const Request = struct {
return 0;
}
- pub const WaitForCompleteHeadError = ReadRawError || error {
+ pub const WaitForCompleteHeadError = ReadRawError || error{
UnexpectedEndOfStream,
HttpHeadersExceededSizeLimit,
@@ -810,27 +822,8 @@ pub const Request = struct {
/// This one can return 0 without meaning EOF.
fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
- if (req.response.done) {
- if (req.response.headers.status.class() == .redirect) {
- if (req.redirects_left == 0) return error.TooManyHttpRedirects;
-
- const location = req.response.headers.location orelse
- return error.HttpRedirectMissingLocation;
- const new_url = try std.Uri.parse(location);
- const new_req = try req.client.request(new_url, req.headers, .{
- .max_redirects = req.redirects_left - 1,
- .header_strategy = if (req.response.header_bytes_owned) .{
- .dynamic = req.response.max_header_bytes,
- } else .{
- .static = req.response.header_bytes.unusedCapacitySlice(),
- },
- });
- req.deinit();
- req.* = new_req;
- } else {
- return 0;
- }
- }
+ assert(req.response.state.isContent());
+ if (req.response.done) return 0;
// var in: []const u8 = undefined;
if (req.read_buffer_start == req.read_buffer_len) {
@@ -851,7 +844,7 @@ pub const Request = struct {
const data_avail = req.response.next_chunk_length;
const out_avail = buffer.len;
- if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+ if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
const can_read = @intCast(usize, @min(buf_avail, data_avail));
req.response.next_chunk_length -= can_read;
@@ -859,7 +852,6 @@ pub const Request = struct {
req.client.release(req.connection);
req.connection = undefined;
req.response.done = true;
- continue;
}
return 0; // skip over as much data as possible
@@ -943,7 +935,7 @@ pub const Request = struct {
const data_avail = req.response.next_chunk_length;
const out_avail = buffer.len - out_index;
- if (req.response.state.isContent() and req.response.headers.status.class() == .redirect) {
+ if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
const can_read = @intCast(usize, @min(buf_avail, data_avail));
req.response.next_chunk_length -= can_read;
@@ -990,9 +982,41 @@ pub const Request = struct {
}
pub fn read(req: *Request, buffer: []u8) ReadError!usize {
- if (!req.response.state.isContent()) try req.waitForCompleteHead();
+ while (true) {
+ if (!req.response.state.isContent()) try req.waitForCompleteHead();
+
+ if (req.handle_redirects and req.response.headers.status.class() == .redirect) {
+ assert(try req.readRaw(buffer) == 0);
+
+ if (req.redirects_left == 0) return error.TooManyHttpRedirects;
+
+ const location = req.response.headers.location orelse
+ return error.HttpRedirectMissingLocation;
+ const new_url = Uri.parse(location) catch try Uri.parseWithoutScheme(location);
+
+ var new_arena = std.heap.ArenaAllocator.init(req.client.allocator);
+ const resolved_url = try req.uri.resolve(new_url, false, new_arena.allocator());
+ errdefer new_arena.deinit();
+
+ req.arena.deinit();
+ req.arena = new_arena;
+
+ const new_req = try req.client.request(resolved_url, req.headers, .{
+ .max_redirects = req.redirects_left - 1,
+ .header_strategy = if (req.response.header_bytes_owned) .{
+ .dynamic = req.response.max_header_bytes,
+ } else .{
+ .static = req.response.header_bytes.unusedCapacitySlice(),
+ },
+ });
+ req.deinit();
+ req.* = new_req;
+ } else {
+ break;
+ }
+ }
- if (req.response.compression == .none and req.response.state.isContent()) {
+ if (req.response.compression == .none) {
if (req.response.headers.transfer_compression) |compression| {
switch (compression) {
.compress => unreachable,
@@ -1084,6 +1108,8 @@ pub const Request = struct {
};
pub fn deinit(client: *Client) void {
+ client.connection_mutex.lock();
+
var next = client.connection_pool.first;
while (next) |node| {
next = node.next;
@@ -1106,7 +1132,7 @@ pub fn deinit(client: *Client) void {
client.* = undefined;
}
-pub const ConnectError = std.mem.Allocator.Error || std.net.TcpConnectToHostError || std.crypto.tls.Client.InitError(std.net.Stream);
+pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
{ // Search through the connection pool for a potential connection.
@@ -1120,7 +1146,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
const same_protocol = node.data.protocol == protocol;
if (same_host and same_port and same_protocol) {
- client.acquire(node);
+ client.acquire(node, true);
return node;
}
@@ -1168,6 +1194,7 @@ pub const RequestError = ConnectError || Connection.WriteError || error{
InvalidPadding,
MissingEndCertificateMarker,
Unseekable,
+ EndOfStream,
};
pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Request.Options) RequestError!Request {
@@ -1196,27 +1223,52 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
}
var req: Request = .{
+ .uri = uri,
.client = client,
.headers = headers,
.connection = try client.connect(host, port, protocol),
.redirects_left = options.max_redirects,
+ .handle_redirects = options.handle_redirects,
+ .compression_init = false,
.response = switch (options.header_strategy) {
.dynamic => |max| Request.Response.initDynamic(max),
.static => |buf| Request.Response.initStatic(buf),
},
+ .arena = undefined,
};
+ req.arena = std.heap.ArenaAllocator.init(client.allocator);
+
{
var buffered = std.io.bufferedWriter(req.connection.data.writer());
const writer = buffered.writer();
+ const escaped_path = try Uri.escapePath(client.allocator, uri.path);
+ defer client.allocator.free(escaped_path);
+
+ const escaped_query = if (uri.query) |q| try Uri.escapeQuery(client.allocator, q) else null;
+ defer if (escaped_query) |q| client.allocator.free(q);
+
+ const escaped_fragment = if (uri.fragment) |f| try Uri.escapeQuery(client.allocator, f) else null;
+ defer if (escaped_fragment) |f| client.allocator.free(f);
+
try writer.writeAll(@tagName(headers.method));
try writer.writeByte(' ');
- try writer.writeAll(uri.path);
+ try writer.writeAll(escaped_path);
+ if (escaped_query) |q| {
+ try writer.writeByte('?');
+ try writer.writeAll(q);
+ }
+ if (escaped_fragment) |f| {
+ try writer.writeByte('#');
+ try writer.writeAll(f);
+ }
try writer.writeByte(' ');
try writer.writeAll(@tagName(headers.version));
try writer.writeAll("\r\nHost: ");
try writer.writeAll(host);
+ try writer.writeAll("\r\nUser-Agent: ");
+ try writer.writeAll(headers.user_agent);
if (headers.connection == .close) {
try writer.writeAll("\r\nConnection: close");
} else {
lib/std/net.zig
@@ -741,9 +741,9 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
return Stream{ .handle = sockfd };
}
-const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error {
+const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError || std.fs.File.ReadError || std.os.SocketError || std.os.BindError || error{
// TODO: break this up into error sets from the various underlying functions
-
+
TemporaryNameServerFailure,
NameServerFailure,
AddressFamilyNotSupported,
@@ -760,7 +760,7 @@ const GetAddressListError = std.mem.Allocator.Error || std.fs.File.OpenError ||
Incomplete,
InvalidIpv4Mapping,
InvalidIPAddressFormat,
-
+
InterfaceNotFound,
FileSystem,
};
lib/std/Uri.zig
@@ -16,15 +16,27 @@ fragment: ?[]const u8,
/// Applies URI encoding and replaces all reserved characters with their respective %XX code.
+/// Only bytes accepted by `isUnreserved` are left as-is.
+/// The result is allocated with `allocator`; the caller frees it.
pub fn escapeString(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+ return escapeStringWithFn(allocator, input, isUnreserved);
+}
+
+/// Like `escapeString`, but leaves every byte accepted by `isPathChar` unescaped, so path
+/// separators such as '/' survive.
+/// The result is allocated with `allocator`; the caller frees it.
+pub fn escapePath(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+ return escapeStringWithFn(allocator, input, isPathChar);
+}
+
+/// Like `escapeString`, but leaves every byte accepted by `isQueryChar` unescaped
+/// (everything `isPathChar` accepts, plus '?').
+/// The result is allocated with `allocator`; the caller frees it.
+pub fn escapeQuery(allocator: std.mem.Allocator, input: []const u8) error{OutOfMemory}![]const u8 {
+ return escapeStringWithFn(allocator, input, isQueryChar);
+}
+
+pub fn escapeStringWithFn(allocator: std.mem.Allocator, input: []const u8, comptime keepUnescaped: fn (c: u8) bool) std.mem.Allocator.Error![]const u8 {
var outsize: usize = 0;
for (input) |c| {
- outsize += if (isUnreserved(c)) @as(usize, 1) else 3;
+ outsize += if (keepUnescaped(c)) @as(usize, 1) else 3;
}
var output = try allocator.alloc(u8, outsize);
var outptr: usize = 0;
for (input) |c| {
- if (isUnreserved(c)) {
+ if (keepUnescaped(c)) {
output[outptr] = c;
outptr += 1;
} else {
@@ -94,13 +106,14 @@ pub fn unescapeString(allocator: std.mem.Allocator, input: []const u8) error{Out
pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort };
-/// Parses the URI or returns an error.
+/// Parses the URI or returns an error. This function is not compliant, but is required to parse
+/// some forms of URIs found in the wild, such as HTTP Location headers.
/// The return value will contain unescaped strings pointing into the
/// original `text`. Each component that is provided, will be non-`null`.
-pub fn parse(text: []const u8) ParseError!Uri {
+pub fn parseWithoutScheme(text: []const u8) ParseError!Uri {
var reader = SliceReader{ .slice = text };
var uri = Uri{
- .scheme = reader.readWhile(isSchemeChar),
+ .scheme = "",
.user = null,
.password = null,
.host = null,
@@ -110,14 +123,6 @@ pub fn parse(text: []const u8) ParseError!Uri {
.fragment = null,
};
- // after the scheme, a ':' must appear
- if (reader.get()) |c| {
- if (c != ':')
- return error.UnexpectedCharacter;
- } else {
- return error.InvalidFormat;
- }
-
if (reader.peekPrefix("//")) { // authority part
std.debug.assert(reader.get().? == '/');
std.debug.assert(reader.get().? == '/');
@@ -179,6 +184,76 @@ pub fn parse(text: []const u8) ParseError!Uri {
return uri;
}
+/// Parses a complete URI (scheme, ':' and the remainder) or returns an error.
+/// The returned components are unescaped slices that point into the original
+/// `text`. Each component that is provided will be non-`null`.
+pub fn parse(text: []const u8) ParseError!Uri {
+    var reader = SliceReader{ .slice = text };
+    const scheme = reader.readWhile(isSchemeChar);
+
+    // The scheme has to be terminated by ':'. Running out of input here means
+    // the text has no valid shape at all; any other byte is simply unexpected.
+    const separator = reader.get() orelse return error.InvalidFormat;
+    if (separator != ':') return error.UnexpectedCharacter;
+
+    // Everything after the ':' is handled by the scheme-less parser.
+    var uri = try parseWithoutScheme(reader.readUntilEof());
+    uri.scheme = scheme;
+    return uri;
+}
+
+
+/// Resolves the reference `R` against the base URI `Base`, following RFC 3986, Section 5.2.
+/// When `strict` is false, a reference whose scheme equals the base scheme is treated as if it
+/// had no scheme at all (the backwards-compatibility allowance of Section 5.2.2).
+/// Any path computed here is allocated from `arena`; every other component of the result
+/// aliases memory owned by `Base` or `R`.
+pub fn resolve(Base: Uri, R: Uri, strict: bool, arena: std.mem.Allocator) !Uri {
+    const use_reference_scheme = R.scheme.len > 0 and
+        (strict or !std.mem.eql(u8, R.scheme, Base.scheme));
+
+    if (use_reference_scheme or R.host != null) {
+        // The reference brings its own scheme and/or authority. Start from a copy of R so
+        // that every component (user, password, host, port, query, fragment) is carried
+        // over, instead of leaving some of them undefined.
+        var T = R;
+        if (!use_reference_scheme) T.scheme = Base.scheme;
+        // Only dot-segment removal is needed; the base path plays no role here.
+        T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
+        return T;
+    }
+
+    // Neither scheme nor authority in the reference: inherit both from the base.
+    var T = Base;
+    T.fragment = R.fragment;
+
+    if (R.path.len == 0) {
+        // Same path as the base; the base query survives only if R has none.
+        T.query = R.query orelse Base.query;
+        return T;
+    }
+
+    T.query = R.query;
+    if (R.path[0] == '/') {
+        T.path = try std.fs.path.resolvePosix(arena, &.{ "/", R.path });
+    } else {
+        // Merge (Section 5.2.3): the reference replaces the last segment of the base path,
+        // so only the base path up to and including its final '/' takes part.
+        const base_dir = if (std.mem.lastIndexOfScalar(u8, Base.path, '/')) |last_slash|
+            Base.path[0 .. last_slash + 1]
+        else
+            "";
+        // NOTE(review): resolvePosix presumably drops a trailing '/', which RFC 3986 would
+        // keep; confirm whether callers care.
+        T.path = try std.fs.path.resolvePosix(arena, &.{ "/", base_dir, R.path });
+    }
+    return T;
+}
+
+
const SliceReader = struct {
const Self = @This();
@@ -284,6 +359,14 @@ fn isPathSeparator(c: u8) bool {
};
}
+/// Returns whether `c` is left unescaped by `escapePath`: '/', ':', '@', or any byte
+/// accepted by `isUnreserved` or `isSubLimit`.
+fn isPathChar(c: u8) bool {
+    return switch (c) {
+        '/', ':', '@' => true,
+        else => isUnreserved(c) or isSubLimit(c),
+    };
+}
+
+/// Returns whether `c` is left unescaped by `escapeQuery`: '?' or any byte accepted
+/// by `isPathChar`.
+fn isQueryChar(c: u8) bool {
+    if (c == '?') return true;
+    return isPathChar(c);
+}
+
fn isQuerySeparator(c: u8) bool {
return switch (c) {
'#' => true,