Commit a0f9a5e78d
Changed files (6)
lib
std
lib/std/crypto/tls/Client.zig
@@ -183,7 +183,6 @@ const InitError = error{
/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
assert(input.buffer.len >= min_buffer_len);
- assert(output.buffer.len >= min_buffer_len);
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@@ -1124,12 +1123,6 @@ fn readIndirect(c: *Client) Reader.Error!usize {
if (record_end > input.buffered().len) return 0;
}
- if (r.seek == r.end) {
- r.seek = 0;
- r.end = 0;
- }
- const cleartext_buffer = r.buffer[r.end..];
-
const cleartext_len, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
.tls_1_3 => {
@@ -1145,7 +1138,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
const operand: V = pad ++ mem.toBytes(big(c.read_seq));
break :nonce @as(V, pv.server_iv) ^ operand;
};
- const cleartext = cleartext_buffer[0..ciphertext.len];
+ rebase(r, ciphertext.len);
+ const cleartext = r.buffer[r.end..][0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
return failRead(c, error.TlsBadRecordMac);
// TODO use scalar, non-slice version
@@ -1171,7 +1165,8 @@ fn readIndirect(c: *Client) Reader.Error!usize {
};
const ciphertext = input.take(message_len) catch unreachable; // already peeked
const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
- const cleartext = cleartext_buffer[0..ciphertext.len];
+ rebase(r, ciphertext.len);
+ const cleartext = r.buffer[r.end..][0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
return failRead(c, error.TlsBadRecordMac);
break :cleartext .{ cleartext.len, ct };
@@ -1179,7 +1174,7 @@ fn readIndirect(c: *Client) Reader.Error!usize {
else => unreachable,
},
};
- const cleartext = cleartext_buffer[0..cleartext_len];
+ const cleartext = r.buffer[r.end..][0..cleartext_len];
c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
switch (inner_ct) {
.alert => {
@@ -1275,6 +1270,15 @@ fn readIndirect(c: *Client) Reader.Error!usize {
}
}
+fn rebase(r: *Reader, capacity: usize) void {
+ if (r.buffer.len - r.end >= capacity) return;
+ const data = r.buffer[r.seek..r.end];
+ @memmove(r.buffer[0..data.len], data);
+ r.seek = 0;
+ r.end = data.len;
+ assert(r.buffer.len - r.end >= capacity);
+}
+
fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
c.read_err = err;
return error.ReadFailed;
lib/std/http/Client.zig
@@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
///
/// If the entire HTTP header cannot fit in this amount of bytes,
/// `error.HttpHeadersOversize` will be returned from `Request.wait`.
-read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
+read_buffer_size: usize = 8192,
/// Each `Connection` allocates this amount for the writer buffer.
write_buffer_size: usize = 1024,
@@ -302,18 +302,22 @@ pub const Connection = struct {
const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len);
errdefer gpa.free(base);
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
- const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
+ // The TLS client wants enough buffer for the max encrypted frame
+ // size, and the HTTP body reader wants enough buffer for the
+ // entire HTTP header. This means we need a combined upper bound.
+ const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+ const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..tls_read_buffer_len];
const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
- const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
- const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
- assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
+ const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
+ const socket_read_buffer = socket_write_buffer.ptr[socket_write_buffer.len..][0..client.tls_buffer_size];
+ assert(base.ptr + alloc_len == socket_read_buffer.ptr + socket_read_buffer.len);
@memcpy(host_buffer, remote_host);
const tls: *Tls = @ptrCast(base);
tls.* = .{
.connection = .{
.client = client,
.stream_writer = stream.writer(tls_write_buffer),
- .stream_reader = stream.reader(tls_read_buffer),
+ .stream_reader = stream.reader(socket_read_buffer),
.pool_node = .{},
.port = port,
.host_len = @intCast(remote_host.len),
@@ -329,8 +333,8 @@ pub const Connection = struct {
.host = .{ .explicit = remote_host },
.ca = .{ .bundle = client.ca_bundle },
.ssl_key_log = client.ssl_key_log,
- .read_buffer = read_buffer,
- .write_buffer = write_buffer,
+ .read_buffer = tls_read_buffer,
+ .write_buffer = socket_write_buffer,
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
.allow_truncation_attacks = true,
@@ -348,8 +352,9 @@ pub const Connection = struct {
}
fn allocLen(client: *Client, host_len: usize) usize {
- return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
- client.write_buffer_size + client.read_buffer_size;
+ const tls_read_buffer_len = client.tls_buffer_size + client.read_buffer_size;
+ return @sizeOf(Tls) + host_len + tls_read_buffer_len + client.tls_buffer_size +
+ client.write_buffer_size + client.tls_buffer_size;
}
fn host(tls: *Tls) []u8 {
@@ -1214,6 +1219,7 @@ pub const Request = struct {
.state = .ready,
// Populated when `http.Reader.bodyReader` is called.
.interface = undefined,
+ .max_head_len = r.client.read_buffer_size,
};
r.redirect_behavior.subtractOne();
}
@@ -1679,6 +1685,7 @@ pub fn request(
.state = .ready,
// Populated when `http.Reader.bodyReader` is called.
.interface = undefined,
+ .max_head_len = client.read_buffer_size,
},
.keep_alive = options.keep_alive,
.method = method,
lib/std/http/Server.zig
@@ -29,6 +29,7 @@ pub fn init(in: *Reader, out: *Writer) Server {
.state = .ready,
// Populated when `http.Reader.bodyReader` is called.
.interface = undefined,
+ .max_head_len = in.buffer.len,
},
.out = out,
};
@@ -251,6 +252,7 @@ pub const Request = struct {
.in = undefined,
.state = .received_head,
.interface = undefined,
+ .max_head_len = 4096,
},
.out = undefined,
};
lib/std/Io/Reader.zig
@@ -86,12 +86,12 @@ pub const VTable = struct {
/// `Reader.buffer`, whichever is bigger.
readVec: *const fn (r: *Reader, data: [][]u8) Error!usize = defaultReadVec,
- /// Ensures `capacity` more data can be buffered without rebasing.
+ /// Ensures `capacity` data can be buffered without rebasing.
///
/// Asserts `capacity` is within buffer capacity, or that the stream ends
/// within `capacity` bytes.
///
- /// Only called when `capacity` cannot fit into the unused capacity of
+ /// Only called when `capacity` cannot be satisfied by unused capacity of
/// `buffer`.
///
/// The default implementation moves buffered data to the start of
@@ -1037,7 +1037,7 @@ fn fillUnbuffered(r: *Reader, n: usize) Error!void {
///
/// Asserts buffer capacity is at least 1.
pub fn fillMore(r: *Reader) Error!void {
- try rebase(r, 1);
+ try rebase(r, r.end - r.seek + 1);
var bufs: [1][]u8 = .{""};
_ = try r.vtable.readVec(r, &bufs);
}
@@ -1205,24 +1205,6 @@ pub fn takeLeb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
} }))) orelse error.Overflow;
}
-pub fn expandTotalCapacity(r: *Reader, allocator: Allocator, n: usize) Allocator.Error!void {
- if (n <= r.buffer.len) return;
- if (r.seek > 0) rebase(r, r.buffer.len);
- var list: ArrayList(u8) = .{
- .items = r.buffer[0..r.end],
- .capacity = r.buffer.len,
- };
- defer r.buffer = list.allocatedSlice();
- try list.ensureTotalCapacity(allocator, n);
-}
-
-pub const FillAllocError = Error || Allocator.Error;
-
-pub fn fillAlloc(r: *Reader, allocator: Allocator, n: usize) FillAllocError!void {
- try expandTotalCapacity(r, allocator, n);
- return fill(r, n);
-}
-
fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Result {
const result_info = @typeInfo(Result).int;
comptime assert(result_info.bits % 7 == 0);
@@ -1253,9 +1235,9 @@ fn takeMultipleOf7Leb128(r: *Reader, comptime Result: type) TakeLeb128Error!Resu
}
}
-/// Ensures `capacity` more data can be buffered without rebasing.
+/// Ensures `capacity` data can be buffered without rebasing.
pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
- if (r.end + capacity <= r.buffer.len) {
+ if (r.buffer.len - r.seek >= capacity) {
@branchHint(.likely);
return;
}
@@ -1263,11 +1245,12 @@ pub fn rebase(r: *Reader, capacity: usize) RebaseError!void {
}
pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
- if (r.end <= r.buffer.len - capacity) return;
+ assert(r.buffer.len - r.seek < capacity);
const data = r.buffer[r.seek..r.end];
@memmove(r.buffer[0..data.len], data);
r.seek = 0;
r.end = data.len;
+ assert(r.buffer.len - r.seek >= capacity);
}
test fixed {
lib/std/http.zig
@@ -329,6 +329,7 @@ pub const Reader = struct {
/// read from `in`.
trailers: []const u8 = &.{},
body_err: ?BodyError = null,
+ max_head_len: usize,
pub const RemainingChunkLen = enum(u64) {
head = 0,
@@ -387,10 +388,11 @@ pub const Reader = struct {
pub fn receiveHead(reader: *Reader) HeadError![]const u8 {
reader.trailers = &.{};
const in = reader.in;
+ const max_head_len = reader.max_head_len;
var hp: HeadParser = .{};
var head_len: usize = 0;
while (true) {
- if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize;
+ if (head_len >= max_head_len) return error.HttpHeadersOversize;
const remaining = in.buffered()[head_len..];
if (remaining.len == 0) {
in.fillMore() catch |err| switch (err) {
lib/std/Io.zig
@@ -717,7 +717,12 @@ pub fn Poller(comptime StreamEnum: type) type {
const unused = r.buffer[r.end..];
if (unused.len >= min_len) return unused;
}
- if (r.seek > 0) r.rebase(r.buffer.len) catch unreachable;
+ if (r.seek > 0) {
+ const data = r.buffer[r.seek..r.end];
+ @memmove(r.buffer[0..data.len], data);
+ r.seek = 0;
+ r.end = data.len;
+ }
{
var list: std.ArrayListUnmanaged(u8) = .{
.items = r.buffer[0..r.end],