Commit 611a1fdd6d
Changed files (3)
lib
std
crypto
http
lib/std/crypto/tls/Client.zig
@@ -37,8 +37,54 @@ application_cipher: tls.ApplicationCipher,
/// `partial_ciphertext_end` describe the span of the segments.
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
+/// This is an example of the type that is needed by the read and write
+/// functions. It can have any fields but it must at least have these
+/// functions.
+///
+/// Note that `std.net.Stream` conforms to this interface.
+///
+/// This declaration serves as documentation only.
+pub const StreamInterface = struct {
+ /// Can be any error set.
+ pub const ReadError = error{};
+
+ /// Returns the number of bytes read. The number read may be less than the
+ /// buffer space provided. End-of-stream is indicated by a return value of 0.
+ ///
+ /// The `iovecs` parameter is mutable because so that function may to
+ /// mutate the fields in order to handle partial reads from the underlying
+ /// stream layer.
+ pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+
+ /// Can be any error set.
+ pub const WriteError = error{};
+
+ /// Returns the number of bytes read, which may be less than the buffer
+ /// space provided. A short read does not indicate end-of-stream.
+ pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+
+ /// Returns the number of bytes read, which may be less than the buffer
+ /// space provided, indicating end-of-stream.
+ /// The `iovecs` parameter is mutable in case this function needs to mutate
+ /// the fields in order to handle partial writes from the underlying layer.
+ pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize {
+ // This can be implemented in terms of writev, or specialized if desired.
+ _ = .{ this, iovecs };
+ @panic("unimplemented");
+ }
+};
+
+/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which
+/// must conform to `StreamInterface`.
+///
/// `host` is only borrowed during this function call.
-pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
+pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) !Client {
const host_len = @intCast(u16, host.len);
var random_buffer: [128]u8 = undefined;
@@ -579,31 +625,115 @@ pub fn init(stream: net.Stream, ca_bundle: Certificate.Bundle, host: []const u8)
}
}
-pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
+ return writeEnd(c, stream, bytes, false);
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try c.write(stream, bytes[index..]);
+ }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
+ var index: usize = 0;
+ while (index < bytes.len) {
+ index += try c.writeEnd(stream, bytes[index..], end);
+ }
+}
+
+/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`.
+/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`.
+/// If `end` is true, then this function additionally sends a `close_notify` alert,
+/// which is necessary for the server to distinguish between a properly finished
+/// TLS session, or a truncation attack.
+pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
+ var iovecs_buf: [6]std.os.iovec_const = undefined;
+ var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
+ if (end) {
+ prepared.iovec_end += prepareCiphertextRecord(
+ c,
+ iovecs_buf[prepared.iovec_end..],
+ ciphertext_buf[prepared.ciphertext_end..],
+ &tls.close_notify_alert,
+ .alert,
+ ).iovec_end;
+ }
+
+ const iovec_end = prepared.iovec_end;
+ const overhead_len = prepared.overhead_len;
+
+ // Ideally we would call writev exactly once here, however, we must ensure
+ // that we don't return with a record partially written.
+ var i: usize = 0;
+ var total_amt: usize = 0;
+ while (true) {
+ var amt = try stream.writev(iovecs_buf[i..iovec_end]);
+ while (amt >= iovecs_buf[i].iov_len) {
+ const encrypted_amt = iovecs_buf[i].iov_len;
+ total_amt += encrypted_amt - overhead_len;
+ amt -= encrypted_amt;
+ i += 1;
+ // Rely on the property that iovecs delineate records, meaning that
+ // if amt equals zero here, we have fortunately found ourselves
+ // with a short read that aligns at the record boundary.
+ if (i >= iovec_end) return total_amt;
+ // We also cannot return on a vector boundary if the final close_notify is
+ // not sent; otherwise the caller would not know to retry the call.
+ if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
+ }
+ iovecs_buf[i].iov_base += amt;
+ iovecs_buf[i].iov_len -= amt;
+ }
+}
+
+fn prepareCiphertextRecord(
+ c: *Client,
+ iovecs: []std.os.iovec_const,
+ ciphertext_buf: []u8,
+ bytes: []const u8,
+ inner_content_type: tls.ContentType,
+) struct {
+ iovec_end: usize,
+ ciphertext_end: usize,
+ /// How many bytes are taken up by overhead per record.
+ overhead_len: usize,
+} {
// Due to the trailing inner content type byte in the ciphertext, we need
// an additional buffer for storing the cleartext into before encrypting.
var cleartext_buf: [max_ciphertext_len]u8 = undefined;
- var iovecs_buf: [5]std.os.iovec_const = undefined;
var ciphertext_end: usize = 0;
var iovec_end: usize = 0;
var bytes_i: usize = 0;
- // How many bytes are taken up by overhead per record.
- const overhead_len: usize = switch (c.application_cipher) {
- inline else => |*p| l: {
+ switch (c.application_cipher) {
+ inline else => |*p| {
const P = @TypeOf(p.*);
const V = @Vector(P.AEAD.nonce_length, u8);
const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
+ const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
while (true) {
const encrypted_content_len = @intCast(u16, @min(
@min(bytes.len - bytes_i, max_ciphertext_len - 1),
- ciphertext_buf.len -
- tls.record_header_len - P.AEAD.tag_length - ciphertext_end - 1,
+ ciphertext_buf.len - close_notify_alert_reserved -
+ overhead_len - ciphertext_end,
));
- if (encrypted_content_len == 0) break :l overhead_len;
+ if (encrypted_content_len == 0) return .{
+ .iovec_end = iovec_end,
+ .ciphertext_end = ciphertext_end,
+ .overhead_len = overhead_len,
+ };
mem.copy(u8, &cleartext_buf, bytes[bytes_i..][0..encrypted_content_len]);
- cleartext_buf[encrypted_content_len] = @enumToInt(tls.ContentType.application_data);
+ cleartext_buf[encrypted_content_len] = @enumToInt(inner_content_type);
bytes_i += encrypted_content_len;
const ciphertext_len = encrypted_content_len + 1;
const cleartext = cleartext_buf[0..ciphertext_len];
@@ -626,40 +756,13 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);
const record = ciphertext_buf[record_start..ciphertext_end];
- iovecs_buf[iovec_end] = .{
+ iovecs[iovec_end] = .{
.iov_base = record.ptr,
.iov_len = record.len,
};
iovec_end += 1;
}
},
- };
-
- // Ideally we would call writev exactly once here, however, we must ensure
- // that we don't return with a record partially written.
- var i: usize = 0;
- var total_amt: usize = 0;
- while (true) {
- var amt = try stream.writev(iovecs_buf[i..iovec_end]);
- while (amt >= iovecs_buf[i].iov_len) {
- const encrypted_amt = iovecs_buf[i].iov_len;
- total_amt += encrypted_amt - overhead_len;
- amt -= encrypted_amt;
- i += 1;
- // Rely on the property that iovecs delineate records, meaning that
- // if amt equals zero here, we have fortunately found ourselves
- // with a short read that aligns at the record boundary.
- if (i >= iovec_end or amt == 0) return total_amt;
- }
- iovecs_buf[i].iov_base += amt;
- iovecs_buf[i].iov_len -= amt;
- }
-}
-
-pub fn writeAll(c: *Client, stream: net.Stream, bytes: []const u8) !void {
- var index: usize = 0;
- while (index < bytes.len) {
- index += try c.write(stream, bytes[index..]);
}
}
@@ -669,6 +772,7 @@ pub fn eof(c: Client) bool {
c.partial_ciphertext_idx >= c.partial_ciphertext_end;
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read, calling the underlying read function the
/// minimal number of times until the buffer has at least `len` bytes filled.
/// If the number read is less than `len` it means the stream reached the end.
@@ -678,10 +782,12 @@ pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize
return readvAtLeast(c, stream, &iovecs, len);
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
return readAtLeast(c, stream, buffer, 1);
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read. If the number read is smaller than
/// `buffer.len`, it means the stream reached the end. Reaching the end of the
/// stream is not an error condition.
@@ -689,6 +795,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
return readAtLeast(c, stream, buffer, buffer.len);
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read. If the number read is less than the space
/// provided it means the stream reached the end. Reaching the end of the
/// stream is not an error condition.
@@ -698,6 +805,7 @@ pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
return readvAtLeast(c, stream, iovecs);
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns the number of bytes read, calling the underlying read function the
/// minimal number of times until the iovecs have at least `len` bytes filled.
/// If the number read is less than `len` it means the stream reached the end.
@@ -722,6 +830,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us
}
}
+/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`.
/// Returns number of bytes that have been read, populated inside `iovecs`. A
/// return value of zero bytes does not mean end of stream. Instead, check the `eof()`
/// for the end of stream. The `eof()` may be true after any call to
@@ -729,7 +838,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: us
/// function asserts that `eof()` is `false`.
/// See `readv` for a higher level function that has the same, familiar API as
/// other read functions, such as `std.fs.File.read`.
-pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iovec) !usize {
+pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize {
var vp: VecPut = .{ .iovecs = iovecs };
// Give away the buffered cleartext we have, if any.
@@ -905,7 +1014,8 @@ pub fn readvAdvanced(c: *Client, stream: net.Stream, iovecs: []const std.os.iove
break :c cleartext;
},
};
- c.read_seq += 1;
+
+ c.read_seq = try std.math.add(u64, c.read_seq, 1);
const inner_ct = @intToEnum(tls.ContentType, cleartext[cleartext.len - 1]);
switch (inner_ct) {
@@ -1196,3 +1306,7 @@ const cipher_suites = enum_array(tls.CipherSuite, &.{
.AES_256_GCM_SHA384,
.CHACHA20_POLY1305_SHA256,
});
+
+test {
+ _ = StreamInterface;
+}
lib/std/crypto/tls.zig
@@ -47,6 +47,11 @@ pub const hello_retry_request_sequence = [32]u8{
0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C,
};
+pub const close_notify_alert = [_]u8{
+ @enumToInt(AlertLevel.warning),
+ @enumToInt(AlertDescription.close_notify),
+};
+
pub const ProtocolVersion = enum(u16) {
tls_1_2 = 0x0303,
tls_1_3 = 0x0304,
lib/std/http/Client.zig
@@ -47,7 +47,7 @@ pub const Request = struct {
try req.stream.writeAll(req.headers.items);
},
.https => {
- try req.tls_client.writeAll(req.stream, req.headers.items);
+ try req.tls_client.writeAllEnd(req.stream, req.headers.items, true);
},
}
}