Commit 920e5bc4ff
lib/std/crypto/Tls.zig
@@ -188,6 +188,12 @@ const NamedGroup = enum(u16) {
// * fragment: opaque
// - the data being transmitted
+// Ciphertext
+// * ContentType opaque_type = application_data; /* 23 */
+// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
+// * uint16 length;
+// * opaque encrypted_record[TLSCiphertext.length];
+
// Handshake:
// * type: HandshakeType
// * length: u24
@@ -331,105 +337,144 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
};
try stream.writevAll(&iovecs);
- {
- var handshake_buf: [4000]u8 = undefined;
+ var handshake_buf: [4000]u8 = undefined;
+ var len: usize = 0;
+ var i: usize = i: {
const plaintext = handshake_buf[0..5];
- const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
- if (amt < plaintext.len) return error.EndOfStream;
+ len = try stream.readAtLeast(&handshake_buf, plaintext.len);
+ if (len < plaintext.len) return error.EndOfStream;
const ct = @intToEnum(ContentType, plaintext[0]);
const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
const end = plaintext.len + frag_len;
- if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
- if (amt < end) {
- const amt2 = try stream.readAll(handshake_buf[amt..end]);
- if (amt2 < plaintext.len) return error.EndOfStream;
+ if (end > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end - len);
+ if (end > len) return error.EndOfStream;
}
const frag = handshake_buf[plaintext.len..end];
- if (ct == .alert) {
- const level = @intToEnum(AlertLevel, frag[0]);
- const desc = @intToEnum(AlertDescription, frag[1]);
- std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
- std.process.exit(1);
- } else if (ct == .handshake) {
- if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
- return error.TlsUnexpectedMessage;
- }
- const length = mem.readIntBig(u24, frag[1..4]);
- if (4 + length != frag.len) return error.TlsBadLength;
- const hello = frag[4..];
- const legacy_version = mem.readIntBig(u16, hello[0..2]);
- const random = hello[2..34].*;
- _ = random;
- const legacy_session_id_echo_len = hello[34];
- if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
- const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
- const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
- return error.TlsIllegalParameter;
- std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
- const legacy_compression_method = hello[37];
- _ = legacy_compression_method;
- const extensions_size = mem.readIntBig(u16, hello[38..40]);
- if (40 + extensions_size != hello.len) return error.TlsBadLength;
- var i: usize = 40;
- var supported_version: u16 = 0;
- var have_server_pub_key = false;
- while (i < hello.len) {
- const et = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- const next_i = i + ext_size;
- if (next_i > hello.len) return error.TlsBadLength;
- switch (et) {
- @enumToInt(ExtensionType.supported_versions) => {
- if (supported_version != 0) return error.TlsIllegalParameter;
- supported_version = mem.readIntBig(u16, hello[i..][0..2]);
- },
- @enumToInt(ExtensionType.key_share) => {
- if (have_server_pub_key) return error.TlsIllegalParameter;
- const named_group = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- switch (named_group) {
- @enumToInt(NamedGroup.x25519) => {
- const key_size = mem.readIntBig(u16, hello[i..][0..2]);
- i += 2;
- if (key_size != 32) return error.TlsBadLength;
- const encrypted_key = hello[i..][0..32].*;
- const server_pub_key = try crypto.dh.X25519.scalarmult(
- tls.x25519_priv_key,
- encrypted_key,
- );
- tls.x25519_server_pub_key = server_pub_key;
- have_server_pub_key = true;
- },
- else => {
- std.debug.print("named group: {x}\n", .{named_group});
- return error.TlsIllegalParameter;
- },
- }
+ switch (ct) {
+ .alert => {
+ const level = @intToEnum(AlertLevel, frag[0]);
+ const desc = @intToEnum(AlertDescription, frag[1]);
+ std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+ return error.TlsAlert;
+ },
+ .handshake => {
+ if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+ return error.TlsUnexpectedMessage;
+ }
+ const length = mem.readIntBig(u24, frag[1..4]);
+ if (4 + length != frag.len) return error.TlsBadLength;
+ const hello = frag[4..];
+ const legacy_version = mem.readIntBig(u16, hello[0..2]);
+ const random = hello[2..34].*;
+ _ = random;
+ const legacy_session_id_echo_len = hello[34];
+ if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
+ const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
+ const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+ return error.TlsIllegalParameter;
+ std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
+ const legacy_compression_method = hello[37];
+ _ = legacy_compression_method;
+ const extensions_size = mem.readIntBig(u16, hello[38..40]);
+ if (40 + extensions_size != hello.len) return error.TlsBadLength;
+ var i: usize = 40;
+ var supported_version: u16 = 0;
+ var have_server_pub_key = false;
+ while (i < hello.len) {
+ const et = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ const next_i = i + ext_size;
+ if (next_i > hello.len) return error.TlsBadLength;
+ switch (et) {
+ @enumToInt(ExtensionType.supported_versions) => {
+ if (supported_version != 0) return error.TlsIllegalParameter;
+ supported_version = mem.readIntBig(u16, hello[i..][0..2]);
+ },
+ @enumToInt(ExtensionType.key_share) => {
+ if (have_server_pub_key) return error.TlsIllegalParameter;
+ const named_group = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ switch (named_group) {
+ @enumToInt(NamedGroup.x25519) => {
+ const key_size = mem.readIntBig(u16, hello[i..][0..2]);
+ i += 2;
+ if (key_size != 32) return error.TlsBadLength;
+ const encrypted_key = hello[i..][0..32].*;
+ const server_pub_key = try crypto.dh.X25519.scalarmult(
+ tls.x25519_priv_key,
+ encrypted_key,
+ );
+ tls.x25519_server_pub_key = server_pub_key;
+ have_server_pub_key = true;
+ },
+ else => {
+ std.debug.print("named group: {x}\n", .{named_group});
+ return error.TlsIllegalParameter;
+ },
+ }
+ },
+ else => {
+ std.debug.print("unexpected extension: {x}\n", .{et});
+ },
+ }
+ i = next_i;
+ }
+ if (!have_server_pub_key) return error.TlsIllegalParameter;
+ const tls_version = if (supported_version == 0) legacy_version else supported_version;
+ switch (tls_version) {
+ @enumToInt(ProtocolVersion.tls_1_2) => {
+ std.debug.print("server wants TLS v1.2\n", .{});
},
- else => {
- std.debug.print("unexpected extension: {x}\n", .{et});
+ @enumToInt(ProtocolVersion.tls_1_3) => {
+ std.debug.print("server wants TLS v1.3\n", .{});
},
+ else => return error.TlsIllegalParameter,
}
- i = next_i;
- }
- if (!have_server_pub_key) return error.TlsIllegalParameter;
- const tls_version = if (supported_version == 0) legacy_version else supported_version;
- switch (tls_version) {
- @enumToInt(ProtocolVersion.tls_1_2) => {
- std.debug.print("server wants TLS v1.2\n", .{});
- },
- @enumToInt(ProtocolVersion.tls_1_3) => {
- std.debug.print("server wants TLS v1.3\n", .{});
- },
- else => return error.TlsIllegalParameter,
- }
- } else {
- std.debug.print("content_type: {s}\n", .{@tagName(ct)});
- std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
+ },
+ else => return error.TlsUnexpectedMessage,
+ }
+ break :i end;
+ };
+
+ while (true) {
+ const end_hdr = i + 5;
+ if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end_hdr > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
+ if (end_hdr > len) return error.EndOfStream;
+ }
+ const ct = @intToEnum(ContentType, handshake_buf[i]);
+ i += 1;
+ const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+ i += 2;
+ _ = legacy_version;
+ const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+ i += 2;
+ const end = i + record_size;
+ if (end > handshake_buf.len) return error.TlsRecordOverflow;
+ if (end > len) {
+ len += try stream.readAtLeast(handshake_buf[len..], end - len);
+ if (end > len) return error.EndOfStream;
+ }
+ switch (ct) {
+ .change_cipher_spec => {
+ if (record_size != 1) return error.TlsUnexpectedMessage;
+ if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
+ },
+ .application_data => {
+ std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
+ },
+ else => {
+ std.debug.print("content type: {s}\n", .{@tagName(ct)});
+ return error.TlsUnexpectedMessage;
+ },
}
+ i = end;
}
tls.state = .sent_hello;
lib/std/net.zig
@@ -1680,9 +1680,9 @@ pub const Stream = struct {
}
/// Returns the number of bytes read, calling the underlying read function
- /// multiple times until at least the buffer has at least `len` bytes
- /// filled. If the number read is less than `len` it means the stream
- /// reached the end. Reaching the end of the stream is not an error
+ /// the minimal number of times until at least the buffer has at least
+ /// `len` bytes filled. If the number read is less than `len` it means the
+ /// stream reached the end. Reaching the end of the stream is not an error
/// condition.
pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
var index: usize = 0;