Commit 28190cc404
Changed files (4)
lib
std
lib/std/crypto/tls/Client.zig
@@ -1,11 +1,15 @@
+const builtin = @import("builtin");
+const native_endian = builtin.cpu.arch.endian();
+
const std = @import("../../std.zig");
const tls = std.crypto.tls;
const Client = @This();
-const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
const Certificate = std.crypto.Certificate;
+const Reader = std.io.Reader;
+const Writer = std.io.Writer;
const max_ciphertext_len = tls.max_ciphertext_len;
const hmacExpandLabel = tls.hmacExpandLabel;
@@ -13,44 +17,58 @@ const hkdfExpandLabel = tls.hkdfExpandLabel;
const int = tls.int;
const array = tls.array;
+/// The encrypted stream from the server to the client. Bytes are pulled from
+/// here via `reader`.
+///
+/// The buffer is asserted to have capacity at least `min_buffer_len`.
+input: *Reader,
+/// Decrypted stream from the server to the client.
+reader: Reader,
+
+/// The encrypted stream from the client to the server. Bytes are pushed here
+/// via `writer`.
+output: *Writer,
+/// The plaintext stream from the client to the server.
+writer: Writer,
+
+/// Populated when `error.TlsAlert` is returned.
+alert: ?tls.Alert = null,
+read_err: ?ReadError = null,
tls_version: tls.ProtocolVersion,
read_seq: u64,
write_seq: u64,
-/// The starting index of cleartext bytes inside `partially_read_buffer`.
-partial_cleartext_idx: u15,
-/// The ending index of cleartext bytes inside `partially_read_buffer` as well
-/// as the starting index of ciphertext bytes.
-partial_ciphertext_idx: u15,
-/// The ending index of ciphertext bytes inside `partially_read_buffer`.
-partial_ciphertext_end: u15,
/// When this is true, the stream may still not be at the end because there
-/// may be data in `partially_read_buffer`.
+/// may be data in the input buffer.
received_close_notify: bool,
-/// By default, reaching the end-of-stream when reading from the server will
-/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
-/// message has been received. By setting this flag to `true`, instead, the
-/// end-of-stream will be forwarded to the application layer above TLS.
-/// This makes the application vulnerable to truncation attacks unless the
-/// application layer itself verifies that the amount of data received equals
-/// the amount of data expected, such as HTTP with the Content-Length header.
allow_truncation_attacks: bool,
application_cipher: tls.ApplicationCipher,
-/// The size is enough to contain exactly one TLSCiphertext record.
-/// This buffer is segmented into four parts:
-/// 0. unused
-/// 1. cleartext
-/// 2. ciphertext
-/// 3. unused
-/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and
-/// `partial_ciphertext_end` describe the span of the segments.
-partially_read_buffer: [tls.max_ciphertext_record_len]u8,
-/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other
-/// programs with access to that file to decrypt all traffic over this connection.
-ssl_key_log: ?struct {
+
+/// If non-null, ssl secrets are logged to a stream. Creating such a log file
+/// allows other programs with access to that file to decrypt all traffic over
+/// this connection.
+ssl_key_log: ?*SslKeyLog,
+
+pub const ReadError = error{
+ /// The alert description will be stored in `alert`.
+ TlsAlert,
+ TlsBadLength,
+ TlsBadRecordMac,
+ TlsConnectionTruncated,
+ TlsDecodeError,
+ TlsRecordOverflow,
+ TlsUnexpectedMessage,
+ TlsIllegalParameter,
+ TlsSequenceOverflow,
+ /// The buffer provided to the read function was not at least
+ /// `min_buffer_len`.
+ OutputBufferUndersize,
+};
+
+pub const SslKeyLog = struct {
client_key_seq: u64,
server_key_seq: u64,
client_random: [32]u8,
- file: std.fs.File,
+ writer: *Writer,
fn clientCounter(key_log: *@This()) u64 {
defer key_log.client_key_seq += 1;
@@ -61,51 +79,12 @@ ssl_key_log: ?struct {
defer key_log.server_key_seq += 1;
return key_log.server_key_seq;
}
-},
-
-/// This is an example of the type that is needed by the read and write
-/// functions. It can have any fields but it must at least have these
-/// functions.
-///
-/// Note that `std.net.Stream` conforms to this interface.
-///
-/// This declaration serves as documentation only.
-pub const StreamInterface = struct {
- /// Can be any error set.
- pub const ReadError = error{};
-
- /// Returns the number of bytes read. The number read may be less than the
- /// buffer space provided. End-of-stream is indicated by a return value of 0.
- ///
- /// The `iovecs` parameter is mutable because so that function may to
- /// mutate the fields in order to handle partial reads from the underlying
- /// stream layer.
- pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize {
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
-
- /// Can be any error set.
- pub const WriteError = error{};
-
- /// Returns the number of bytes read, which may be less than the buffer
- /// space provided. A short read does not indicate end-of-stream.
- pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize {
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
-
- /// Returns the number of bytes read, which may be less than the buffer
- /// space provided, indicating end-of-stream.
- /// The `iovecs` parameter is mutable in case this function needs to mutate
- /// the fields in order to handle partial writes from the underlying layer.
- pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize {
- // This can be implemented in terms of writev, or specialized if desired.
- _ = .{ this, iovecs };
- @panic("unimplemented");
- }
};
+/// The `Reader` supplied to `init` requires a buffer capacity
+/// at least this amount.
+pub const min_buffer_len = tls.max_ciphertext_record_len;
+
pub const Options = struct {
/// How to perform host verification of server certificates.
host: union(enum) {
@@ -127,64 +106,85 @@ pub const Options = struct {
/// Verify that the server certificate is authorized by a given ca bundle.
bundle: Certificate.Bundle,
},
- /// If non-null, ssl secrets are logged to this file. Creating such a log file allows
+ /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows
/// other programs with access to that file to decrypt all traffic over this connection.
- ssl_key_log_file: ?std.fs.File = null,
+ ///
+ /// Only the `writer` field is observed during the handshake (`init`).
+ /// After that, the other fields are populated.
+ ssl_key_log: ?*SslKeyLog = null,
+ /// By default, reaching the end-of-stream when reading from the server will
+ /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify
+ /// message has been received. By setting this flag to `true`, instead, the
+ /// end-of-stream will be forwarded to the application layer above TLS.
+ ///
+ /// This makes the application vulnerable to truncation attacks unless the
+ /// application layer itself verifies that the amount of data received equals
+ /// the amount of data expected, such as HTTP with the Content-Length header.
+ allow_truncation_attacks: bool = false,
+ write_buffer: []u8,
+ /// Asserted to have capacity at least `min_buffer_len`.
+ read_buffer: []u8,
+ /// Populated when `error.TlsAlert` is returned from `init`.
+ alert: ?*tls.Alert = null,
};
-pub fn InitError(comptime Stream: type) type {
- return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{
- InsufficientEntropy,
- DiskQuota,
- LockViolation,
- NotOpenForWriting,
- TlsUnexpectedMessage,
- TlsIllegalParameter,
- TlsDecryptFailure,
- TlsRecordOverflow,
- TlsBadRecordMac,
- CertificateFieldHasInvalidLength,
- CertificateHostMismatch,
- CertificatePublicKeyInvalid,
- CertificateExpired,
- CertificateFieldHasWrongDataType,
- CertificateIssuerMismatch,
- CertificateNotYetValid,
- CertificateSignatureAlgorithmMismatch,
- CertificateSignatureAlgorithmUnsupported,
- CertificateSignatureInvalid,
- CertificateSignatureInvalidLength,
- CertificateSignatureNamedCurveUnsupported,
- CertificateSignatureUnsupportedBitCount,
- TlsCertificateNotVerified,
- TlsBadSignatureScheme,
- TlsBadRsaSignatureBitCount,
- InvalidEncoding,
- IdentityElement,
- SignatureVerificationFailed,
- TlsDecryptError,
- TlsConnectionTruncated,
- TlsDecodeError,
- UnsupportedCertificateVersion,
- CertificateTimeInvalid,
- CertificateHasUnrecognizedObjectId,
- CertificateHasInvalidBitString,
- MessageTooLong,
- NegativeIntoUnsigned,
- TargetTooSmall,
- BufferTooSmall,
- InvalidSignature,
- NotSquare,
- NonCanonical,
- WeakPublicKey,
- };
-}
+const InitError = error{
+ WriteFailed,
+ ReadFailed,
+ InsufficientEntropy,
+ DiskQuota,
+ LockViolation,
+ NotOpenForWriting,
+ /// The alert description will be stored in `alert`.
+ TlsAlert,
+ TlsUnexpectedMessage,
+ TlsIllegalParameter,
+ TlsDecryptFailure,
+ TlsRecordOverflow,
+ TlsBadRecordMac,
+ CertificateFieldHasInvalidLength,
+ CertificateHostMismatch,
+ CertificatePublicKeyInvalid,
+ CertificateExpired,
+ CertificateFieldHasWrongDataType,
+ CertificateIssuerMismatch,
+ CertificateNotYetValid,
+ CertificateSignatureAlgorithmMismatch,
+ CertificateSignatureAlgorithmUnsupported,
+ CertificateSignatureInvalid,
+ CertificateSignatureInvalidLength,
+ CertificateSignatureNamedCurveUnsupported,
+ CertificateSignatureUnsupportedBitCount,
+ TlsCertificateNotVerified,
+ TlsBadSignatureScheme,
+ TlsBadRsaSignatureBitCount,
+ InvalidEncoding,
+ IdentityElement,
+ SignatureVerificationFailed,
+ TlsDecryptError,
+ TlsConnectionTruncated,
+ TlsDecodeError,
+ UnsupportedCertificateVersion,
+ CertificateTimeInvalid,
+ CertificateHasUnrecognizedObjectId,
+ CertificateHasInvalidBitString,
+ MessageTooLong,
+ NegativeIntoUnsigned,
+ TargetTooSmall,
+ BufferTooSmall,
+ InvalidSignature,
+ NotSquare,
+ NonCanonical,
+ WeakPublicKey,
+};
-/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which
-/// must conform to `StreamInterface`.
+/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session.
///
/// `host` is only borrowed during this function call.
-pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client {
+///
+/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
+pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
+ assert(input.buffer.len >= min_buffer_len);
const host = switch (options.host) {
.no_verification => "",
.explicit => |host| host,
@@ -276,11 +276,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
};
{
- var iovecs = [_]std.posix.iovec_const{
- .{ .base = cleartext_header.ptr, .len = cleartext_header.len },
- .{ .base = host.ptr, .len = host.len },
- };
- try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]);
+ var iovecs: [2][]const u8 = .{ cleartext_header, host };
+ try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
}
var tls_version: tls.ProtocolVersion = undefined;
@@ -329,20 +326,26 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
var cleartext_fragment_start: usize = 0;
var cleartext_fragment_end: usize = 0;
var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined;
- var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined;
- var d: tls.Decoder = .{ .buf = &handshake_buffer };
fragment: while (true) {
- try d.readAtLeastOurAmt(stream, tls.record_header_len);
- const record_header = d.buf[d.idx..][0..tls.record_header_len];
- const record_ct = d.decode(tls.ContentType);
- d.skip(2); // legacy_version
- const record_len = d.decode(u16);
- try d.readAtLeast(stream, record_len);
- var record_decoder = try d.sub(record_len);
+ // Ensure the input buffer pointer is stable in this scope.
+ input.rebaseCapacity(tls.max_ciphertext_record_len);
+ const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked
+ input.toss(2); // legacy_version
+ const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked
+ if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow;
+ const record_buffer = input.take(record_len) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer);
var ctd, const ct = content: switch (cipher_state) {
.cleartext => .{ record_decoder, record_ct },
.handshake => {
- std.debug.assert(tls_version == .tls_1_3);
+ assert(tls_version == .tls_1_3);
if (record_ct != .application_data) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -374,7 +377,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct };
},
.application => {
- std.debug.assert(tls_version == .tls_1_2);
+ assert(tls_version == .tls_1_2);
if (record_ct != .handshake) return error.TlsUnexpectedMessage;
try record_decoder.ensure(record_len);
const cleartext_buf = &cleartext_bufs[cert_buf_index % 2];
@@ -412,14 +415,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
switch (ct) {
.alert => {
ctd.ensure(2) catch continue :fragment;
- const level = ctd.decode(tls.AlertLevel);
- const desc = ctd.decode(tls.AlertDescription);
- _ = level;
-
- // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
- try desc.toError();
- // TODO: handle server-side closures
- return error.TlsUnexpectedMessage;
+ if (options.alert) |a| a.* = .{
+ .level = ctd.decode(tls.Alert.Level),
+ .description = ctd.decode(tls.Alert.Description),
+ };
+ return error.TlsAlert;
},
.change_cipher_spec => {
ctd.ensure(1) catch continue :fragment;
@@ -533,7 +533,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.client_random = &client_hello_rand,
}, .{
.SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret,
@@ -707,7 +707,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
&client_hello_rand,
&server_hello_rand,
}, 48);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.client_random = &client_hello_rand,
}, .{
.CLIENT_RANDOM = &master_secret,
@@ -755,11 +755,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
nonce,
pv.app_cipher.client_write_key,
);
- const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg;
- var all_msgs_vec = [_]std.posix.iovec_const{
- .{ .base = &all_msgs, .len = all_msgs.len },
+ var all_msgs_vec: [3][]const u8 = .{
+ &client_key_exchange_msg,
+ &client_change_cipher_spec_msg,
+ &client_verify_msg,
};
- try stream.writevAll(&all_msgs_vec);
+ try output.writeVecAll(&all_msgs_vec);
},
}
write_seq += 1;
@@ -820,15 +821,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
const nonce = pv.client_handshake_iv;
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key);
- const all_msgs = client_change_cipher_spec_msg ++ finished_msg;
- var all_msgs_vec = [_]std.posix.iovec_const{
- .{ .base = &all_msgs, .len = all_msgs.len },
+ var all_msgs_vec: [2][]const u8 = .{
+ &client_change_cipher_spec_msg,
+ &finished_msg,
};
- try stream.writevAll(&all_msgs_vec);
+ try output.writeVecAll(&all_msgs_vec);
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
- if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{
+ if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
.counter = key_seq,
.client_random = &client_hello_rand,
}, .{
@@ -855,8 +856,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
else => unreachable,
},
};
- const leftover = d.rest();
- var client: Client = .{
+ if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{
+ .client_key_seq = key_seq,
+ .server_key_seq = key_seq,
+ .client_random = client_hello_rand,
+ .writer = ssl_key_log.writer,
+ };
+ return .{
+ .input = input,
+ .reader = .{
+ .buffer = options.read_buffer,
+ .vtable = &.{ .stream = stream },
+ .seek = 0,
+ .end = 0,
+ },
+ .output = output,
+ .writer = .{
+ .buffer = options.write_buffer,
+ .vtable = &.{
+ .drain = drain,
+ .sendFile = Writer.unimplementedSendFile,
+ },
+ },
.tls_version = tls_version,
.read_seq = switch (tls_version) {
.tls_1_3 => 0,
@@ -868,22 +889,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
.tls_1_2 => write_seq,
else => unreachable,
},
- .partial_cleartext_idx = 0,
- .partial_ciphertext_idx = 0,
- .partial_ciphertext_end = @intCast(leftover.len),
.received_close_notify = false,
- .allow_truncation_attacks = false,
+ .allow_truncation_attacks = options.allow_truncation_attacks,
.application_cipher = app_cipher,
- .partially_read_buffer = undefined,
- .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{
- .client_key_seq = key_seq,
- .server_key_seq = key_seq,
- .client_random = client_hello_rand,
- .file = key_log_file,
- } else null,
+ .ssl_key_log = options.ssl_key_log,
};
- @memcpy(client.partially_read_buffer[0..leftover.len], leftover);
- return client;
},
else => return error.TlsUnexpectedMessage,
}
@@ -897,94 +907,48 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client
}
}
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
-pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
- return writeEnd(c, stream, bytes, false);
-}
-
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try c.write(stream, bytes[index..]);
- }
-}
-
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// If `end` is true, then this function additionally sends a `close_notify` alert,
-/// which is necessary for the server to distinguish between a properly finished
-/// TLS session, or a truncation attack.
-pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try c.writeEnd(stream, bytes[index..], end);
+fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
+ const c: *Client = @fieldParentPtr("writer", w);
+ if (true) @panic("update to use the buffer and flush");
+ const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
+ const output = c.output;
+ const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+ var total_clear: usize = 0;
+ var ciphertext_end: usize = 0;
+ for (sliced_data) |buf| {
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
+ total_clear += prepared.cleartext_len;
+ ciphertext_end += prepared.ciphertext_end;
+ if (total_clear < buf.len) break;
}
+ output.advance(ciphertext_end);
+ return total_clear;
}
-/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
-/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`.
-/// If `end` is true, then this function additionally sends a `close_notify` alert,
-/// which is necessary for the server to distinguish between a properly finished
-/// TLS session, or a truncation attack.
-pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
- var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
- var iovecs_buf: [6]std.posix.iovec_const = undefined;
- var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
- if (end) {
- prepared.iovec_end += prepareCiphertextRecord(
- c,
- iovecs_buf[prepared.iovec_end..],
- ciphertext_buf[prepared.ciphertext_end..],
- &tls.close_notify_alert,
- .alert,
- ).iovec_end;
- }
-
- const iovec_end = prepared.iovec_end;
- const overhead_len = prepared.overhead_len;
-
- // Ideally we would call writev exactly once here, however, we must ensure
- // that we don't return with a record partially written.
- var i: usize = 0;
- var total_amt: usize = 0;
- while (true) {
- var amt = try stream.writev(iovecs_buf[i..iovec_end]);
- while (amt >= iovecs_buf[i].len) {
- const encrypted_amt = iovecs_buf[i].len;
- total_amt += encrypted_amt - overhead_len;
- amt -= encrypted_amt;
- i += 1;
- // Rely on the property that iovecs delineate records, meaning that
- // if amt equals zero here, we have fortunately found ourselves
- // with a short read that aligns at the record boundary.
- if (i >= iovec_end) return total_amt;
- // We also cannot return on a vector boundary if the final close_notify is
- // not sent; otherwise the caller would not know to retry the call.
- if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
- }
- iovecs_buf[i].base += amt;
- iovecs_buf[i].len -= amt;
- }
+/// Sends a `close_notify` alert, which is necessary for the server to
+/// distinguish between a properly finished TLS session, or a truncation
+/// attack.
+pub fn end(c: *Client) Writer.Error!void {
+ const output = c.output;
+ const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
+ const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
+ output.advance(prepared.cleartext_len);
+ return prepared.ciphertext_end;
}
fn prepareCiphertextRecord(
c: *Client,
- iovecs: []std.posix.iovec_const,
ciphertext_buf: []u8,
bytes: []const u8,
inner_content_type: tls.ContentType,
) struct {
- iovec_end: usize,
ciphertext_end: usize,
- /// How many bytes are taken up by overhead per record.
- overhead_len: usize,
+ cleartext_len: usize,
} {
// Due to the trailing inner content type byte in the ciphertext, we need
// an additional buffer for storing the cleartext into before encrypting.
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
var ciphertext_end: usize = 0;
- var iovec_end: usize = 0;
var bytes_i: usize = 0;
switch (c.application_cipher) {
inline else => |*p| switch (c.tls_version) {
@@ -992,18 +956,15 @@ fn prepareCiphertextRecord(
const pv = &p.tls_1_3;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
- const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const encrypted_content_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
- ciphertext_buf.len -|
- (close_notify_alert_reserved + overhead_len + ciphertext_end),
+ ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (encrypted_content_len == 0) return .{
- .iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
- .overhead_len = overhead_len,
+ .cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]);
@@ -1012,7 +973,6 @@ fn prepareCiphertextRecord(
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
- const record_start = ciphertext_end;
const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++
int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
@@ -1030,38 +990,27 @@ fn prepareCiphertextRecord(
};
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key);
c.write_seq += 1; // TODO send key_update on overflow
-
- const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs[iovec_end] = .{
- .base = record.ptr,
- .len = record.len,
- };
- iovec_end += 1;
}
},
.tls_1_2 => {
const pv = &p.tls_1_2;
const P = @TypeOf(p.*);
const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length;
- const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const message_len: u16 = @min(
bytes.len - bytes_i,
tls.max_ciphertext_inner_record_len,
- ciphertext_buf.len -|
- (close_notify_alert_reserved + overhead_len + ciphertext_end),
+ ciphertext_buf.len -| (overhead_len + ciphertext_end),
);
if (message_len == 0) return .{
- .iovec_end = iovec_end,
.ciphertext_end = ciphertext_end,
- .overhead_len = overhead_len,
+ .cleartext_len = bytes_i,
};
@memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]);
bytes_i += message_len;
const cleartext = cleartext_buf[0..message_len];
- const record_start = ciphertext_end;
const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len];
ciphertext_end += tls.record_header_len;
record_header.* = .{@intFromEnum(inner_content_type)} ++
@@ -1083,13 +1032,6 @@ fn prepareCiphertextRecord(
ciphertext_end += P.mac_length;
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key);
c.write_seq += 1; // TODO send key_update on overflow
-
- const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs[iovec_end] = .{
- .base = record.ptr,
- .len = record.len,
- };
- iovec_end += 1;
}
},
else => unreachable,
@@ -1098,421 +1040,194 @@ fn prepareCiphertextRecord(
}
pub fn eof(c: Client) bool {
- return c.received_close_notify and
- c.partial_cleartext_idx >= c.partial_ciphertext_idx and
- c.partial_ciphertext_idx >= c.partial_ciphertext_end;
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read, calling the underlying read function the
-/// minimal number of times until the buffer has at least `len` bytes filled.
-/// If the number read is less than `len` it means the stream reached the end.
-/// Reaching the end of the stream is not an error condition.
-pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
- var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }};
- return readvAtLeast(c, stream, &iovecs, len);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
- return readAtLeast(c, stream, buffer, 1);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read. If the number read is smaller than
-/// `buffer.len`, it means the stream reached the end. Reaching the end of the
-/// stream is not an error condition.
-pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
- return readAtLeast(c, stream, buffer, buffer.len);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read. If the number read is less than the space
-/// provided it means the stream reached the end. Reaching the end of the
-/// stream is not an error condition.
-/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
-/// order to handle partial reads from the underlying stream layer.
-pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
- return readvAtLeast(c, stream, iovecs, 1);
-}
-
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns the number of bytes read, calling the underlying read function the
-/// minimal number of times until the iovecs have at least `len` bytes filled.
-/// If the number read is less than `len` it means the stream reached the end.
-/// Reaching the end of the stream is not an error condition.
-/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
-/// order to handle partial reads from the underlying stream layer.
-pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize {
- if (c.eof()) return 0;
-
- var off_i: usize = 0;
- var vec_i: usize = 0;
- while (true) {
- var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
- off_i += amt;
- if (c.eof() or off_i >= len) return off_i;
- while (amt >= iovecs[vec_i].len) {
- amt -= iovecs[vec_i].len;
- vec_i += 1;
- }
- iovecs[vec_i].base += amt;
- iovecs[vec_i].len -= amt;
- }
+ return c.received_close_notify;
}
-/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
-/// Returns number of bytes that have been read, populated inside `iovecs`. A
-/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
-/// for the end of stream. The `eof()` may be true after any call to
-/// `read`, including when greater than zero bytes are returned, and this
-/// function asserts that `eof()` is `false`.
-/// See `readv` for a higher level function that has the same, familiar API as
-/// other read functions, such as `std.fs.File.read`.
-pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize {
- var vp: VecPut = .{ .iovecs = iovecs };
-
- // Give away the buffered cleartext we have, if any.
- const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx];
- if (partial_cleartext.len > 0) {
- const amt: u15 = @intCast(vp.put(partial_cleartext));
- c.partial_cleartext_idx += amt;
-
- if (c.partial_cleartext_idx == c.partial_ciphertext_idx and
- c.partial_ciphertext_end == c.partial_ciphertext_idx)
- {
- // The buffer is now empty.
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = 0;
- }
-
- if (c.received_close_notify) {
- c.partial_ciphertext_end = 0;
- assert(vp.total == amt);
- return amt;
- } else if (amt > 0) {
- // We don't need more data, so don't call read.
- assert(vp.total == amt);
- return amt;
- }
+fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
+ const c: *Client = @fieldParentPtr("reader", r);
+ if (c.eof()) return error.EndOfStream;
+ const input = c.input;
+ // If at least one full encrypted record is not buffered, read once.
+ const record_header = input.peek(tls.record_header_len) catch |err| switch (err) {
+ error.EndOfStream => {
+ // This is either a truncation attack, a bug in the server, or an
+ // intentional omission of the close_notify message due to truncation
+ // detection handled above the TLS layer.
+ if (c.allow_truncation_attacks) {
+ c.received_close_notify = true;
+ return error.EndOfStream;
+ } else {
+ return failRead(c, error.TlsConnectionTruncated);
+ }
+ },
+ error.ReadFailed => return error.ReadFailed,
+ };
+ const ct: tls.ContentType = @enumFromInt(record_header[0]);
+ const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big);
+ _ = legacy_version;
+ const record_len = mem.readInt(u16, record_header[3..][0..2], .big);
+ if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow);
+ const record_end = 5 + record_len;
+ if (record_end > input.buffered().len) {
+ input.fillMore() catch |err| switch (err) {
+ error.EndOfStream => return failRead(c, error.TlsConnectionTruncated),
+ error.ReadFailed => return error.ReadFailed,
+ };
+ if (record_end > input.buffered().len) return 0;
}
- assert(!c.received_close_notify);
-
- // Ideally, this buffer would never be used. It is needed when `iovecs` are
- // too small to fit the cleartext, which may be as large as `max_ciphertext_len`.
var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined;
- // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`.
- var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined;
- // How many bytes left in the user's buffer.
- const free_size = vp.freeSize();
- // The amount of the user's buffer that we need to repurpose for storing
- // ciphertext. The end of the buffer will be used for such purposes.
- const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len;
- // The amount of the user's buffer that will be used to give cleartext. The
- // beginning of the buffer will be used for such purposes.
- const cleartext_buf_len = free_size - ciphertext_buf_len;
-
- // Recoup `partially_read_buffer` space. This is necessary because it is assumed
- // below that `frag0` is big enough to hold at least one record.
- limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx);
- c.partial_ciphertext_end -= c.partial_ciphertext_idx;
- c.partial_ciphertext_idx = 0;
- c.partial_cleartext_idx = 0;
- const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];
-
- var ask_iovecs_buf: [2]std.posix.iovec = .{
- .{
- .base = first_iov.ptr,
- .len = first_iov.len,
- },
- .{
- .base = &in_stack_buffer,
- .len = in_stack_buffer.len,
+ const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
+ inline else => |*p| switch (c.tls_version) {
+ .tls_1_3 => {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const ad = input.take(tls.record_header_len) catch unreachable; // already peeked
+ const ciphertext_len = record_len - P.AEAD.tag_length;
+ const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked
+ const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked
+ const nonce = nonce: {
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
+ break :nonce @as(V, pv.server_iv) ^ operand;
+ };
+ const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
+ return failRead(c, error.TlsBadRecordMac);
+ const msg = mem.trimRight(u8, cleartext, "\x00");
+ break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+ },
+ .tls_1_2 => {
+ const pv = &p.tls_1_2;
+ const P = @TypeOf(p.*);
+ const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
+ const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked
+ const ad = std.mem.toBytes(big(c.read_seq)) ++
+ ad_header[0 .. 1 + 2] ++
+ std.mem.toBytes(big(message_len));
+ const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked
+ const masked_read_seq = c.read_seq &
+ comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
+ const nonce: [P.AEAD.nonce_length]u8 = nonce: {
+ const V = @Vector(P.AEAD.nonce_length, u8);
+ const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
+ const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
+ break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
+ };
+ const ciphertext = input.take(message_len) catch unreachable; // already peeked
+ const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked
+ const cleartext = cleartext_stack_buffer[0..ciphertext.len];
+ P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
+ return failRead(c, error.TlsBadRecordMac);
+ break :cleartext .{ cleartext, ct };
+ },
+ else => unreachable,
},
};
-
- // Cleartext capacity of output buffer, in records. Minimum one full record.
- const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
- const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len);
- const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end;
- const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len);
- const actual_read_len = try stream.readv(ask_iovecs);
- if (actual_read_len == 0) {
- // This is either a truncation attack, a bug in the server, or an
- // intentional omission of the close_notify message due to truncation
- // detection handled above the TLS layer.
- if (c.allow_truncation_attacks) {
- c.received_close_notify = true;
- } else {
- return error.TlsConnectionTruncated;
- }
- }
-
- // There might be more bytes inside `in_stack_buffer` that need to be processed,
- // but at least frag0 will have one complete ciphertext record.
- const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len);
- const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end];
- var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len];
- // We need to decipher frag0 and frag1 but there may be a ciphertext record
- // straddling the boundary. We can handle this with two memcpy() calls to
- // assemble the straddling record in between handling the two sides.
- var frag = frag0;
- var in: usize = 0;
- while (true) {
- if (in == frag.len) {
- // Perfect split.
- if (frag.ptr == frag1.ptr) {
- c.partial_ciphertext_end = c.partial_ciphertext_idx;
- return vp.total;
- }
- frag = frag1;
- in = 0;
- continue;
- }
-
- if (in + tls.record_header_len > frag.len) {
- if (frag.ptr == frag1.ptr)
- return finishRead(c, frag, in, vp.total);
-
- const first = frag[in..];
-
- if (frag1.len < tls.record_header_len)
- return finishRead2(c, first, frag1, vp.total);
-
- // A record straddles the two fragments. Copy into the now-empty first fragment.
- const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3);
- const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4);
- const record_len = (record_len_byte_0 << 8) | record_len_byte_1;
- if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
-
- const full_record_len = record_len + tls.record_header_len;
- const second_len = full_record_len - first.len;
- if (frag1.len < second_len)
- return finishRead2(c, first, frag1, vp.total);
-
- limitedOverlapCopy(frag, in);
- @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
- frag = frag[0..full_record_len];
- frag1 = frag1[second_len..];
- in = 0;
- continue;
- }
- const ct: tls.ContentType = @enumFromInt(frag[in]);
- in += 1;
- const legacy_version = mem.readInt(u16, frag[in..][0..2], .big);
- in += 2;
- _ = legacy_version;
- const record_len = mem.readInt(u16, frag[in..][0..2], .big);
- if (record_len > max_ciphertext_len) return error.TlsRecordOverflow;
- in += 2;
- const end = in + record_len;
- if (end > frag.len) {
- // We need the record header on the next iteration of the loop.
- in -= tls.record_header_len;
-
- if (frag.ptr == frag1.ptr)
- return finishRead(c, frag, in, vp.total);
-
- // A record straddles the two fragments. Copy into the now-empty first fragment.
- const first = frag[in..];
- const full_record_len = record_len + tls.record_header_len;
- const second_len = full_record_len - first.len;
- if (frag1.len < second_len)
- return finishRead2(c, first, frag1, vp.total);
-
- limitedOverlapCopy(frag, in);
- @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]);
- frag = frag[0..full_record_len];
- frag1 = frag1[second_len..];
- in = 0;
- continue;
- }
- const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) {
- inline else => |*p| switch (c.tls_version) {
- .tls_1_3 => {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len];
- const ciphertext_len = record_len - P.AEAD.tag_length;
- const ciphertext = frag[in..][0..ciphertext_len];
- in += ciphertext_len;
- const auth_tag = frag[in..][0..P.AEAD.tag_length].*;
- const nonce = nonce: {
- const V = @Vector(P.AEAD.nonce_length, u8);
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ std.mem.toBytes(big(c.read_seq));
- break :nonce @as(V, pv.server_iv) ^ operand;
- };
- const out_buf = vp.peek();
- const cleartext_buf = if (ciphertext.len <= out_buf.len)
- out_buf
- else
- &cleartext_stack_buffer;
- const cleartext = cleartext_buf[0..ciphertext.len];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch
- return error.TlsBadRecordMac;
- const msg = mem.trimEnd(u8, cleartext, "\x00");
- break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) };
+ c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow);
+ switch (inner_ct) {
+ .alert => {
+ if (cleartext.len != 2) return failRead(c, error.TlsDecodeError);
+ const alert: tls.Alert = .{
+ .level = @enumFromInt(cleartext[0]),
+ .description = @enumFromInt(cleartext[1]),
+ };
+ switch (alert.description) {
+ .close_notify => {
+ c.received_close_notify = true;
+ return 0;
},
- .tls_1_2 => {
- const pv = &p.tls_1_2;
- const P = @TypeOf(p.*);
- const message_len: u16 = record_len - P.record_iv_length - P.mac_length;
- const ad = std.mem.toBytes(big(c.read_seq)) ++
- frag[in - tls.record_header_len ..][0 .. 1 + 2] ++
- std.mem.toBytes(big(message_len));
- const record_iv = frag[in..][0..P.record_iv_length].*;
- in += P.record_iv_length;
- const masked_read_seq = c.read_seq &
- comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length);
- const nonce: [P.AEAD.nonce_length]u8 = nonce: {
- const V = @Vector(P.AEAD.nonce_length, u8);
- const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
- const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq)));
- break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand;
- };
- const ciphertext = frag[in..][0..message_len];
- in += message_len;
- const auth_tag = frag[in..][0..P.mac_length].*;
- in += P.mac_length;
- const out_buf = vp.peek();
- const cleartext_buf = if (message_len <= out_buf.len)
- out_buf
- else
- &cleartext_stack_buffer;
- const cleartext = cleartext_buf[0..ciphertext.len];
- P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch
- return error.TlsBadRecordMac;
- break :cleartext .{ cleartext, ct };
+ .user_canceled => {
+ // TODO: handle server-side closures
+ return failRead(c, error.TlsUnexpectedMessage);
},
- else => unreachable,
- },
- };
- c.read_seq = try std.math.add(u64, c.read_seq, 1);
- switch (inner_ct) {
- .alert => {
- if (cleartext.len != 2) return error.TlsDecodeError;
- const level: tls.AlertLevel = @enumFromInt(cleartext[0]);
- const desc: tls.AlertDescription = @enumFromInt(cleartext[1]);
- if (desc == .close_notify) {
- c.received_close_notify = true;
- c.partial_ciphertext_end = c.partial_ciphertext_idx;
- return vp.total;
- }
- _ = level;
-
- try desc.toError();
- // TODO: handle server-side closures
- return error.TlsUnexpectedMessage;
- },
- .handshake => {
- var ct_i: usize = 0;
- while (true) {
- const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
- ct_i += 1;
- const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
- ct_i += 3;
- const next_handshake_i = ct_i + handshake_len;
- if (next_handshake_i > cleartext.len)
- return error.TlsBadLength;
- const handshake = cleartext[ct_i..next_handshake_i];
- switch (handshake_type) {
- .new_session_ticket => {
- // This client implementation ignores new session tickets.
- },
- .key_update => {
- switch (c.application_cipher) {
- inline else => |*p| {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
- if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
- .counter = key_log.serverCounter(),
- .client_random = &key_log.client_random,
- }, .{
- .SERVER_TRAFFIC_SECRET = &server_secret,
- });
- pv.server_secret = server_secret;
- pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
- pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
- },
- }
- c.read_seq = 0;
-
- switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
- .update_requested => {
- switch (c.application_cipher) {
- inline else => |*p| {
- const pv = &p.tls_1_3;
- const P = @TypeOf(p.*);
- const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
- if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{
- .counter = key_log.clientCounter(),
- .client_random = &key_log.client_random,
- }, .{
- .CLIENT_TRAFFIC_SECRET = &client_secret,
- });
- pv.client_secret = client_secret;
- pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
- pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
- },
- }
- c.write_seq = 0;
- },
- .update_not_requested => {},
- _ => return error.TlsIllegalParameter,
- }
- },
- else => {
- return error.TlsUnexpectedMessage;
- },
- }
- ct_i = next_handshake_i;
- if (ct_i >= cleartext.len) break;
- }
- },
- .application_data => {
- // Determine whether the output buffer or a stack
- // buffer was used for storing the cleartext.
- if (cleartext.ptr == &cleartext_stack_buffer) {
- // Stack buffer was used, so we must copy to the output buffer.
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // We have already run out of room in iovecs. Continue
- // appending to `partially_read_buffer`.
- @memcpy(
- c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len],
- cleartext,
- );
- c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len);
- } else {
- const amt = vp.put(cleartext);
- if (amt < cleartext.len) {
- const rest = cleartext[amt..];
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = @intCast(rest.len);
- @memcpy(c.partially_read_buffer[0..rest.len], rest);
+ else => {
+ c.alert = alert;
+ return failRead(c, error.TlsAlert);
+ },
+ }
+ },
+ .handshake => {
+ var ct_i: usize = 0;
+ while (true) {
+ const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]);
+ ct_i += 1;
+ const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big);
+ ct_i += 3;
+ const next_handshake_i = ct_i + handshake_len;
+ if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength);
+ const handshake = cleartext[ct_i..next_handshake_i];
+ switch (handshake_type) {
+ .new_session_ticket => {
+ // This client implementation ignores new session tickets.
+ },
+ .key_update => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length);
+ if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+ .counter = key_log.serverCounter(),
+ .client_random = &key_log.client_random,
+ }, .{
+ .SERVER_TRAFFIC_SECRET = &server_secret,
+ });
+ pv.server_secret = server_secret;
+ pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
+ pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.read_seq = 0;
+
+ switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) {
+ .update_requested => {
+ switch (c.application_cipher) {
+ inline else => |*p| {
+ const pv = &p.tls_1_3;
+ const P = @TypeOf(p.*);
+ const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length);
+ if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{
+ .counter = key_log.clientCounter(),
+ .client_random = &key_log.client_random,
+ }, .{
+ .CLIENT_TRAFFIC_SECRET = &client_secret,
+ });
+ pv.client_secret = client_secret;
+ pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
+ pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
+ },
+ }
+ c.write_seq = 0;
+ },
+ .update_not_requested => {},
+ _ => return failRead(c, error.TlsIllegalParameter),
}
- }
- } else {
- // Output buffer was used directly which means no
- // memory copying needs to occur, and we can move
- // on to the next ciphertext record.
- vp.next(cleartext.len);
+ },
+ else => return failRead(c, error.TlsUnexpectedMessage),
}
- },
- else => return error.TlsUnexpectedMessage,
- }
- in = end;
+ ct_i = next_handshake_i;
+ if (ct_i >= cleartext.len) break;
+ }
+ return 0;
+ },
+ .application_data => {
+ if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize);
+ try w.writeAll(cleartext);
+ return cleartext.len;
+ },
+ else => return failRead(c, error.TlsUnexpectedMessage),
}
}
-fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void {
- const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false;
- defer if (locked) key_log_file.unlock();
- key_log_file.seekFromEnd(0) catch {};
- inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.deprecatedWriter().print("{s}" ++
+fn failRead(c: *Client, err: ReadError) error{ReadFailed} {
+ c.read_err = err;
+ return error.ReadFailed;
+}
+
+fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void {
+ inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++
(if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++
(if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{
context.client_random,
@@ -1520,62 +1235,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi
}) catch {};
}
-fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
- const saved_buf = frag[in..];
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // There is cleartext at the beginning already which we need to preserve.
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len);
- @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf);
- } else {
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = @intCast(saved_buf.len);
- @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
- }
- return out;
-}
-
-/// Note that `first` usually overlaps with `c.partially_read_buffer`.
-fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
- if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
- // There is cleartext at the beginning already which we need to preserve.
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
- // TODO: eliminate this call to copyForwards
- std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
- @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
- } else {
- c.partial_cleartext_idx = 0;
- c.partial_ciphertext_idx = 0;
- c.partial_ciphertext_end = @intCast(first.len + frag1.len);
- // TODO: eliminate this call to copyForwards
- std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
- @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
- }
- return out;
-}
-
-fn limitedOverlapCopy(frag: []u8, in: usize) void {
- const first = frag[in..];
- if (first.len <= in) {
- // A single, non-overlapping memcpy suffices.
- @memcpy(frag[0..first.len], first);
- } else {
- // One memcpy call would overlap, so just do this instead.
- std.mem.copyForwards(u8, frag, first);
- }
-}
-
-fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
- if (index < s1.len) {
- return s1[index];
- } else {
- return s2[index - s1.len];
- }
-}
-
-const builtin = @import("builtin");
-const native_endian = builtin.cpu.arch.endian();
-
fn big(x: anytype) @TypeOf(x) {
return switch (native_endian) {
.big => x,
@@ -1836,81 +1495,6 @@ const CertificatePublicKey = struct {
}
};
-/// Abstraction for sending multiple byte buffers to a slice of iovecs.
-const VecPut = struct {
- iovecs: []const std.posix.iovec,
- idx: usize = 0,
- off: usize = 0,
- total: usize = 0,
-
- /// Returns the amount actually put which is always equal to bytes.len
- /// unless the vectors ran out of space.
- fn put(vp: *VecPut, bytes: []const u8) usize {
- if (vp.idx >= vp.iovecs.len) return 0;
- var bytes_i: usize = 0;
- while (true) {
- const v = vp.iovecs[vp.idx];
- const dest = v.base[vp.off..v.len];
- const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
- @memcpy(dest[0..src.len], src);
- bytes_i += src.len;
- vp.off += src.len;
- if (vp.off >= v.len) {
- vp.off = 0;
- vp.idx += 1;
- if (vp.idx >= vp.iovecs.len) {
- vp.total += bytes_i;
- return bytes_i;
- }
- }
- if (bytes_i >= bytes.len) {
- vp.total += bytes_i;
- return bytes_i;
- }
- }
- }
-
- /// Returns the next buffer that consecutive bytes can go into.
- fn peek(vp: VecPut) []u8 {
- if (vp.idx >= vp.iovecs.len) return &.{};
- const v = vp.iovecs[vp.idx];
- return v.base[vp.off..v.len];
- }
-
- // After writing to the result of peek(), one can call next() to
- // advance the cursor.
- fn next(vp: *VecPut, len: usize) void {
- vp.total += len;
- vp.off += len;
- if (vp.off >= vp.iovecs[vp.idx].len) {
- vp.off = 0;
- vp.idx += 1;
- }
- }
-
- fn freeSize(vp: VecPut) usize {
- if (vp.idx >= vp.iovecs.len) return 0;
- var total: usize = 0;
- total += vp.iovecs[vp.idx].len - vp.off;
- if (vp.idx + 1 >= vp.iovecs.len) return total;
- for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len;
- return total;
- }
-};
-
-/// Limit iovecs to a specific byte size.
-fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec {
- var bytes_left: usize = len;
- for (iovecs, 0..) |*iovec, vec_i| {
- if (bytes_left <= iovec.len) {
- iovec.len = bytes_left;
- return iovecs[0 .. vec_i + 1];
- }
- bytes_left -= iovec.len;
- }
- return iovecs;
-}
-
/// The priority order here is chosen based on what crypto algorithms Zig has
/// available in the standard library as well as what is faster. Following are
/// a few data points on the relative performance of these algorithms.
@@ -1954,7 +1538,3 @@ else
.AES_256_GCM_SHA384,
.ECDHE_RSA_WITH_AES_256_GCM_SHA384,
});
-
-test {
- _ = StreamInterface;
-}
lib/std/crypto/tls.zig
@@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{
};
pub const close_notify_alert = [_]u8{
- @intFromEnum(AlertLevel.warning),
- @intFromEnum(AlertDescription.close_notify),
+ @intFromEnum(Alert.Level.warning),
+ @intFromEnum(Alert.Description.close_notify),
};
pub const ProtocolVersion = enum(u16) {
@@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) {
_,
};
-pub const AlertLevel = enum(u8) {
- warning = 1,
- fatal = 2,
- _,
-};
+pub const Alert = struct {
+ level: Level,
+ description: Description,
-pub const AlertDescription = enum(u8) {
- pub const Error = error{
- TlsAlertUnexpectedMessage,
- TlsAlertBadRecordMac,
- TlsAlertRecordOverflow,
- TlsAlertHandshakeFailure,
- TlsAlertBadCertificate,
- TlsAlertUnsupportedCertificate,
- TlsAlertCertificateRevoked,
- TlsAlertCertificateExpired,
- TlsAlertCertificateUnknown,
- TlsAlertIllegalParameter,
- TlsAlertUnknownCa,
- TlsAlertAccessDenied,
- TlsAlertDecodeError,
- TlsAlertDecryptError,
- TlsAlertProtocolVersion,
- TlsAlertInsufficientSecurity,
- TlsAlertInternalError,
- TlsAlertInappropriateFallback,
- TlsAlertMissingExtension,
- TlsAlertUnsupportedExtension,
- TlsAlertUnrecognizedName,
- TlsAlertBadCertificateStatusResponse,
- TlsAlertUnknownPskIdentity,
- TlsAlertCertificateRequired,
- TlsAlertNoApplicationProtocol,
- TlsAlertUnknown,
+ pub const Level = enum(u8) {
+ warning = 1,
+ fatal = 2,
+ _,
};
- close_notify = 0,
- unexpected_message = 10,
- bad_record_mac = 20,
- record_overflow = 22,
- handshake_failure = 40,
- bad_certificate = 42,
- unsupported_certificate = 43,
- certificate_revoked = 44,
- certificate_expired = 45,
- certificate_unknown = 46,
- illegal_parameter = 47,
- unknown_ca = 48,
- access_denied = 49,
- decode_error = 50,
- decrypt_error = 51,
- protocol_version = 70,
- insufficient_security = 71,
- internal_error = 80,
- inappropriate_fallback = 86,
- user_canceled = 90,
- missing_extension = 109,
- unsupported_extension = 110,
- unrecognized_name = 112,
- bad_certificate_status_response = 113,
- unknown_psk_identity = 115,
- certificate_required = 116,
- no_application_protocol = 120,
- _,
+ pub const Description = enum(u8) {
+ pub const Error = error{
+ TlsAlertUnexpectedMessage,
+ TlsAlertBadRecordMac,
+ TlsAlertRecordOverflow,
+ TlsAlertHandshakeFailure,
+ TlsAlertBadCertificate,
+ TlsAlertUnsupportedCertificate,
+ TlsAlertCertificateRevoked,
+ TlsAlertCertificateExpired,
+ TlsAlertCertificateUnknown,
+ TlsAlertIllegalParameter,
+ TlsAlertUnknownCa,
+ TlsAlertAccessDenied,
+ TlsAlertDecodeError,
+ TlsAlertDecryptError,
+ TlsAlertProtocolVersion,
+ TlsAlertInsufficientSecurity,
+ TlsAlertInternalError,
+ TlsAlertInappropriateFallback,
+ TlsAlertMissingExtension,
+ TlsAlertUnsupportedExtension,
+ TlsAlertUnrecognizedName,
+ TlsAlertBadCertificateStatusResponse,
+ TlsAlertUnknownPskIdentity,
+ TlsAlertCertificateRequired,
+ TlsAlertNoApplicationProtocol,
+ TlsAlertUnknown,
+ };
- pub fn toError(alert: AlertDescription) Error!void {
- switch (alert) {
- .close_notify => {}, // not an error
- .unexpected_message => return error.TlsAlertUnexpectedMessage,
- .bad_record_mac => return error.TlsAlertBadRecordMac,
- .record_overflow => return error.TlsAlertRecordOverflow,
- .handshake_failure => return error.TlsAlertHandshakeFailure,
- .bad_certificate => return error.TlsAlertBadCertificate,
- .unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
- .certificate_revoked => return error.TlsAlertCertificateRevoked,
- .certificate_expired => return error.TlsAlertCertificateExpired,
- .certificate_unknown => return error.TlsAlertCertificateUnknown,
- .illegal_parameter => return error.TlsAlertIllegalParameter,
- .unknown_ca => return error.TlsAlertUnknownCa,
- .access_denied => return error.TlsAlertAccessDenied,
- .decode_error => return error.TlsAlertDecodeError,
- .decrypt_error => return error.TlsAlertDecryptError,
- .protocol_version => return error.TlsAlertProtocolVersion,
- .insufficient_security => return error.TlsAlertInsufficientSecurity,
- .internal_error => return error.TlsAlertInternalError,
- .inappropriate_fallback => return error.TlsAlertInappropriateFallback,
- .user_canceled => {}, // not an error
- .missing_extension => return error.TlsAlertMissingExtension,
- .unsupported_extension => return error.TlsAlertUnsupportedExtension,
- .unrecognized_name => return error.TlsAlertUnrecognizedName,
- .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
- .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
- .certificate_required => return error.TlsAlertCertificateRequired,
- .no_application_protocol => return error.TlsAlertNoApplicationProtocol,
- _ => return error.TlsAlertUnknown,
+ close_notify = 0,
+ unexpected_message = 10,
+ bad_record_mac = 20,
+ record_overflow = 22,
+ handshake_failure = 40,
+ bad_certificate = 42,
+ unsupported_certificate = 43,
+ certificate_revoked = 44,
+ certificate_expired = 45,
+ certificate_unknown = 46,
+ illegal_parameter = 47,
+ unknown_ca = 48,
+ access_denied = 49,
+ decode_error = 50,
+ decrypt_error = 51,
+ protocol_version = 70,
+ insufficient_security = 71,
+ internal_error = 80,
+ inappropriate_fallback = 86,
+ user_canceled = 90,
+ missing_extension = 109,
+ unsupported_extension = 110,
+ unrecognized_name = 112,
+ bad_certificate_status_response = 113,
+ unknown_psk_identity = 115,
+ certificate_required = 116,
+ no_application_protocol = 120,
+ _,
+
+ pub fn toError(description: Description) Error!void {
+ switch (description) {
+ .close_notify => {}, // not an error
+ .unexpected_message => return error.TlsAlertUnexpectedMessage,
+ .bad_record_mac => return error.TlsAlertBadRecordMac,
+ .record_overflow => return error.TlsAlertRecordOverflow,
+ .handshake_failure => return error.TlsAlertHandshakeFailure,
+ .bad_certificate => return error.TlsAlertBadCertificate,
+ .unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
+ .certificate_revoked => return error.TlsAlertCertificateRevoked,
+ .certificate_expired => return error.TlsAlertCertificateExpired,
+ .certificate_unknown => return error.TlsAlertCertificateUnknown,
+ .illegal_parameter => return error.TlsAlertIllegalParameter,
+ .unknown_ca => return error.TlsAlertUnknownCa,
+ .access_denied => return error.TlsAlertAccessDenied,
+ .decode_error => return error.TlsAlertDecodeError,
+ .decrypt_error => return error.TlsAlertDecryptError,
+ .protocol_version => return error.TlsAlertProtocolVersion,
+ .insufficient_security => return error.TlsAlertInsufficientSecurity,
+ .internal_error => return error.TlsAlertInternalError,
+ .inappropriate_fallback => return error.TlsAlertInappropriateFallback,
+ .user_canceled => {}, // not an error
+ .missing_extension => return error.TlsAlertMissingExtension,
+ .unsupported_extension => return error.TlsAlertUnsupportedExtension,
+ .unrecognized_name => return error.TlsAlertUnrecognizedName,
+ .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
+ .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
+ .certificate_required => return error.TlsAlertCertificateRequired,
+ .no_application_protocol => return error.TlsAlertNoApplicationProtocol,
+ _ => return error.TlsAlertUnknown,
+ }
}
- }
+ };
};
pub const SignatureScheme = enum(u16) {
@@ -650,7 +655,7 @@ pub const Decoder = struct {
}
/// Use this function to increase `their_end`.
- pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
+ pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void {
assert(!d.disable_reads);
const existing_amt = d.cap - d.idx;
d.their_end = d.idx + their_amt;
@@ -658,14 +663,16 @@ pub const Decoder = struct {
const request_amt = their_amt - existing_amt;
const dest = d.buf[d.cap..];
if (request_amt > dest.len) return error.TlsRecordOverflow;
- const actual_amt = try stream.readAtLeast(dest, request_amt);
- if (actual_amt < request_amt) return error.TlsConnectionTruncated;
- d.cap += actual_amt;
+ stream.readSlice(dest[0..request_amt]) catch |err| switch (err) {
+ error.EndOfStream => return error.TlsConnectionTruncated,
+ error.ReadFailed => return error.ReadFailed,
+ };
+ d.cap += request_amt;
}
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
/// Use when `our_amt` is calculated by us, not by them.
- pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
+ pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void {
assert(!d.disable_reads);
try readAtLeast(d, stream, our_amt);
d.our_end = d.idx + our_amt;
lib/std/Io/Reader.zig
@@ -1306,31 +1306,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
r.end = data.len;
}
-/// Advances the stream and decreases the size of the storage buffer by `n`,
-/// returning the range of bytes no longer accessible by `r`.
-///
-/// This action can be undone by `restitute`.
-///
-/// Asserts there are at least `n` buffered bytes already.
-///
-/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state.
-pub fn steal(r: *Reader, n: usize) []u8 {
- assert(r.seek == 0);
- assert(n <= r.end);
- const stolen = r.buffer[0..n];
- r.buffer = r.buffer[n..];
- r.end -= n;
- return stolen;
-}
-
-/// Expands the storage buffer, undoing the effects of `steal`
-/// Assumes that `n` does not exceed the total number of stolen bytes.
-pub fn restitute(r: *Reader, n: usize) void {
- r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n];
- r.end += n;
- r.seek += n;
-}
-
test fixed {
var r: Reader = .fixed("a\x02");
try testing.expect((try r.takeByte()) == 'a');
lib/std/http.zig
@@ -343,10 +343,9 @@ pub const Reader = struct {
/// read from `in`.
trailers: []const u8 = &.{},
body_err: ?BodyError = null,
- /// Stolen from `in`.
- head_buffer: []u8 = &.{},
-
- pub const max_chunk_header_len = 22;
+ /// Determines at which point `error.HttpHeadersOversize` occurs, as well
+ /// as the minimum buffer capacity of `in`.
+ max_head_len: usize,
pub const RemainingChunkLen = enum(u64) {
head = 0,
@@ -398,19 +397,11 @@ pub const Reader = struct {
ReadFailed,
};
- pub fn restituteHeadBuffer(reader: *Reader) void {
- reader.in.restitute(reader.head_buffer.len);
- reader.head_buffer.len = 0;
- }
-
- /// Buffers the entire head into `head_buffer`, invalidating the previous
- /// `head_buffer`, if any.
+ /// Buffers the entire head.
pub fn receiveHead(reader: *Reader) HeadError!void {
reader.trailers = &.{};
const in = reader.in;
- in.restitute(reader.head_buffer.len);
- reader.head_buffer.len = 0;
- in.rebase();
+ try in.rebase(reader.max_head_len);
var hp: HeadParser = .{};
var head_end: usize = 0;
while (true) {