Commit 940d368e7e
Changed files (3)
lib
std
lib/std/crypto/tls/Client.zig
@@ -18,14 +18,20 @@ const array = tls.array;
const enum_array = tls.enum_array;
const Certificate = crypto.Certificate;
-application_cipher: ApplicationCipher,
read_seq: u64,
write_seq: u64,
-/// The size is enough to contain exactly one TLSCiphertext record.
-partially_read_buffer: [tls.max_ciphertext_record_len]u8,
/// The number of partially read bytes inside `partially_read_buffer`.
partially_read_len: u15,
+/// The number of cleartext bytes from decoding `partially_read_buffer` which
+/// have already been transferred via read() calls. This implementation will
+/// re-decrypt bytes from `partially_read_buffer` when the buffer supplied by
+/// the read() API user is not large enough.
+partial_cleartext_index: u15,
+application_cipher: ApplicationCipher,
eof: bool,
+/// The size is enough to contain exactly one TLSCiphertext record.
+/// Contains encrypted bytes.
+partially_read_buffer: [tls.max_ciphertext_record_len]u8,
/// `host` is only borrowed during this function call.
pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
@@ -596,6 +602,7 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
.application_cipher = app_cipher,
.read_seq = 0,
.write_seq = 0,
+ .partial_cleartext_index = 0,
.partially_read_buffer = undefined,
.partially_read_len = @intCast(u15, len - end),
.eof = false,
@@ -722,27 +729,85 @@ pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
}
}
-/// Returns number of bytes that have been read, which are now populated inside
-/// `buffer`. A return value of zero bytes does not necessarily mean end of
-/// stream. Instead, the `eof` flag is set upon end of stream. The `eof` flag
-/// may be set after any call to `read`, including when greater than zero bytes
-/// are returned, and this function asserts that `eof` is `false`.
-pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
+/// Returns the number of bytes read, calling the underlying read function the
+/// minimal number of times until the buffer has at least `len` bytes filled.
+/// If the number read is less than `len` it means the stream reached the end.
+/// Reaching the end of the stream is not an error condition.
+pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
+ assert(len <= buffer.len);
+ if (c.eof) return 0;
+ var index: usize = 0;
+ while (index < len) {
+ index += try c.readAdvanced(stream, buffer[index..]);
+ if (c.eof) break;
+ }
+ return index;
+}
+
+pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
+ return readAtLeast(c, stream, buffer, 1);
+}
+
+/// Returns the number of bytes read. If the number read is smaller than
+/// `buffer.len`, it means the stream reached the end. Reaching the end of the
+/// stream is not an error condition.
+pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
+ return readAtLeast(c, stream, buffer, buffer.len);
+}
+
+/// Returns number of bytes that have been read, populated inside `buffer`. A
+/// return value of zero bytes does not mean end of stream. Instead, the `eof`
+/// flag is set upon end of stream. The `eof` flag may be set after any call to
+/// `read`, including when greater than zero bytes are returned, and this
+/// function asserts that `eof` is `false`.
+/// See `read` for a higher level function that has the same, familiar API
+/// as other read functions, such as `std.fs.File.read`.
+/// It is recommended to use a buffer size with length at least
+/// `tls.max_ciphertext_len` bytes to avoid redundantly decrypting the same
+/// encoded data.
+pub fn readAdvanced(c: *Client, stream: net.Stream, buffer: []u8) !usize {
assert(!c.eof);
const prev_len = c.partially_read_len;
- var in_buf: [max_ciphertext_len * 4]u8 = undefined;
- mem.copy(u8, &in_buf, c.partially_read_buffer[0..prev_len]);
+ // Ideally, this buffer would never be used. It is needed when `buffer` is too small
+ // to fit the cleartext, which may be as large as `max_ciphertext_len`.
+ var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
+ // This buffer is typically used, except, as an optimization when a very large
+ // `buffer` is provided, we use half of it for buffering ciphertext and the
+ // other half for outputting cleartext.
+ var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
+ const half_buffer_len = buffer.len / 2;
+ const out_in: struct { []u8, []u8 } = if (half_buffer_len >= in_stack_buffer.len) .{
+ buffer[0..half_buffer_len],
+ buffer[half_buffer_len..],
+ } else .{
+ buffer,
+ &in_stack_buffer,
+ };
+ const out_buf = out_in[0];
+ const in_buf = out_in[1];
+ mem.copy(u8, in_buf, c.partially_read_buffer[0..prev_len]);
// Capacity of output buffer, in records, rounded up.
- const buf_cap = (buffer.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
+ const buf_cap = (out_buf.len +| (max_ciphertext_len - 1)) / max_ciphertext_len;
const wanted_read_len = buf_cap * (max_ciphertext_len + tls.ciphertext_record_header_len);
- const ask_slice = in_buf[prev_len..@min(wanted_read_len, in_buf.len)];
- const actual_read_len = try stream.read(ask_slice);
- const frag = in_buf[0 .. prev_len + actual_read_len];
- if (frag.len == 0) {
- // This is either a truncation attack, or a bug in the server.
- return error.TlsConnectionTruncated;
- }
+ const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len);
+ const ask_slice = in_buf[prev_len..][0..@min(ask_len, in_buf.len - prev_len)];
+ assert(ask_slice.len > 0);
+ const frag = frag: {
+ if (prev_len >= 5) {
+ const record_size = mem.readIntBig(u16, in_buf[3..][0..2]);
+ if (prev_len >= 5 + record_size) {
+ // We can use our buffered data without calling read().
+ break :frag in_buf[0..prev_len];
+ }
+ }
+ const actual_read_len = try stream.read(ask_slice);
+ if (actual_read_len == 0) {
+ // This is either a truncation attack, or a bug in the server.
+ return error.TlsConnectionTruncated;
+ }
+ break :frag in_buf[0 .. prev_len + actual_read_len];
+ };
var in: usize = 0;
var out: usize = 0;
@@ -750,6 +815,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
if (in + tls.ciphertext_record_header_len > frag.len) {
return finishRead(c, frag, in, out);
}
+ const record_start = in;
const ct = @intToEnum(ContentType, frag[in]);
in += 1;
const legacy_version = mem.readIntBig(u16, frag[in..][0..2]);
@@ -767,7 +833,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
@panic("TODO handle an alert here");
},
.application_data => {
- const cleartext_len = switch (c.application_cipher) {
+ const cleartext = switch (c.application_cipher) {
inline else => |*p| c: {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
@@ -776,29 +842,29 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
const ciphertext = frag[in..][0..ciphertext_len];
in += ciphertext_len;
const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
- const cleartext = buffer[out..][0..ciphertext_len];
const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ // Here we use read_seq and then intentionally don't
+ // increment it until later when it is certain the same
+ // ciphertext does not need to be decrypted again.
const operand: V = pad ++ @bitCast([8]u8, big(c.read_seq));
- c.read_seq += 1;
const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand;
- //std.debug.print("seq: {d} nonce: {} server_key: {} server_iv: {}\n", .{
- // c.read_seq - 1,
- // std.fmt.fmtSliceHexLower(&nonce),
- // std.fmt.fmtSliceHexLower(&p.server_key),
- // std.fmt.fmtSliceHexLower(&p.server_iv),
- //});
+ const cleartext_buf = if (c.partial_cleartext_index == 0 and out + ciphertext.len <= out_buf.len)
+ out_buf[out..]
+ else
+ &cleartext_stack_buffer;
+ const cleartext = cleartext_buf[0..ciphertext.len];
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch
return error.TlsBadRecordMac;
- break :c cleartext.len;
+ break :c cleartext;
},
};
- const cleartext = buffer[out..][0..cleartext_len];
const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) {
.alert => {
- const level = @intToEnum(tls.AlertLevel, buffer[out]);
- const desc = @intToEnum(tls.AlertDescription, buffer[out + 1]);
+ c.read_seq += 1;
+ const level = @intToEnum(tls.AlertLevel, out_buf[out]);
+ const desc = @intToEnum(tls.AlertDescription, out_buf[out + 1]);
if (desc == .close_notify) {
c.eof = true;
return out;
@@ -807,6 +873,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
return error.TlsAlert;
},
.handshake => {
+ c.read_seq += 1;
var ct_i: usize = 0;
while (true) {
const handshake_type = @intToEnum(tls.HandshakeType, cleartext[ct_i]);
@@ -819,7 +886,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
const handshake = cleartext[ct_i..next_handshake_i];
switch (handshake_type) {
.new_session_ticket => {
- std.debug.print("server sent a new session ticket\n", .{});
+ // This client implementation ignores new session tickets.
},
.key_update => {
switch (c.application_cipher) {
@@ -859,7 +926,35 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
}
},
.application_data => {
- out += cleartext_len - 1;
+ // Determine whether the output buffer or a stack
+ // buffer was used for storing the cleartext.
+ if (c.partial_cleartext_index == 0 and
+ out + cleartext.len <= out_buf.len)
+ {
+ // Output buffer was used directly which means no
+ // memory copying needs to occur, and we can move
+ // on to the next ciphertext record.
+ out += cleartext.len - 1;
+ c.read_seq += 1;
+ } else {
+ // Stack buffer was used, so we must copy to the output buffer.
+ const dest = out_buf[out..];
+ const rest = cleartext[c.partial_cleartext_index..];
+ const src = rest[0..@min(rest.len, dest.len)];
+ mem.copy(u8, dest, src);
+ out += src.len;
+ c.partial_cleartext_index = @intCast(
+ @TypeOf(c.partial_cleartext_index),
+ c.partial_cleartext_index + src.len,
+ );
+ if (c.partial_cleartext_index >= cleartext.len) {
+ c.partial_cleartext_index = 0;
+ c.read_seq += 1;
+ } else {
+ in = record_start;
+ return finishRead(c, frag, in, out);
+ }
+ }
},
else => {
std.debug.print("inner content type: {d}\n", .{inner_ct});
lib/std/http/Client.zig
@@ -63,16 +63,10 @@ pub const Request = struct {
}
pub fn readAtLeast(req: *Request, buffer: []u8, len: usize) !usize {
- var index: usize = 0;
- while (index < len) {
- const amt = try req.read(buffer[index..]);
- index += amt;
- switch (req.protocol) {
- .http => if (amt == 0) break,
- .https => if (req.tls_client.eof) break,
- }
+ switch (req.protocol) {
+ .http => return req.stream.readAtLeast(buffer, len),
+ .https => return req.tls_client.readAtLeast(req.stream, buffer, len),
}
- return index;
}
};
lib/std/net.zig
@@ -1680,11 +1680,12 @@ pub const Stream = struct {
}
/// Returns the number of bytes read, calling the underlying read function
- /// the minimal number of times until at least the buffer has at least
- /// `len` bytes filled. If the number read is less than `len` it means the
- /// stream reached the end. Reaching the end of the stream is not an error
+ /// the minimal number of times until the buffer has at least `len` bytes
+ /// filled. If the number read is less than `len` it means the stream
+ /// reached the end. Reaching the end of the stream is not an error
/// condition.
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
+ assert(len <= buffer.len);
var index: usize = 0;
while (index < len) {
const amt = try s.read(buffer[index..]);