Commit 920e5bc4ff

Andrew Kelley <andrew@ziglang.org>
2022-12-14 05:59:01
std.crypto.Tls: discard ChangeCipherSpec messages
The next step here is to decrypt encrypted records
1 parent d2f5d0b
Changed files (2)
lib
std
lib/std/crypto/Tls.zig
@@ -188,6 +188,12 @@ const NamedGroup = enum(u16) {
 // * fragment: opaque
 //   - the data being transmitted
 
+// Ciphertext
+// * ContentType opaque_type = application_data; /* 23 */
+// * ProtocolVersion legacy_record_version = 0x0303; /* TLS v1.2 */
+// * uint16 length;
+// * opaque encrypted_record[TLSCiphertext.length];
+
 // Handshake:
 // * type: HandshakeType
 // * length: u24
@@ -331,105 +337,144 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
     };
     try stream.writevAll(&iovecs);
 
-    {
-        var handshake_buf: [4000]u8 = undefined;
+    var handshake_buf: [4000]u8 = undefined;
+    var len: usize = 0;
+    var i: usize = i: {
         const plaintext = handshake_buf[0..5];
-        const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
-        if (amt < plaintext.len) return error.EndOfStream;
+        len = try stream.readAtLeast(&handshake_buf, plaintext.len);
+        if (len < plaintext.len) return error.EndOfStream;
         const ct = @intToEnum(ContentType, plaintext[0]);
         const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
         const end = plaintext.len + frag_len;
-        if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
-        if (amt < end) {
-            const amt2 = try stream.readAll(handshake_buf[amt..end]);
-            if (amt2 < plaintext.len) return error.EndOfStream;
+        if (end > handshake_buf.len) return error.TlsRecordOverflow;
+        if (end > len) {
+            len += try stream.readAtLeast(handshake_buf[len..], end - len);
+            if (end > len) return error.EndOfStream;
         }
         const frag = handshake_buf[plaintext.len..end];
 
-        if (ct == .alert) {
-            const level = @intToEnum(AlertLevel, frag[0]);
-            const desc = @intToEnum(AlertDescription, frag[1]);
-            std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
-            std.process.exit(1);
-        } else if (ct == .handshake) {
-            if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
-                return error.TlsUnexpectedMessage;
-            }
-            const length = mem.readIntBig(u24, frag[1..4]);
-            if (4 + length != frag.len) return error.TlsBadLength;
-            const hello = frag[4..];
-            const legacy_version = mem.readIntBig(u16, hello[0..2]);
-            const random = hello[2..34].*;
-            _ = random;
-            const legacy_session_id_echo_len = hello[34];
-            if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
-            const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
-            const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
-                return error.TlsIllegalParameter;
-            std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
-            const legacy_compression_method = hello[37];
-            _ = legacy_compression_method;
-            const extensions_size = mem.readIntBig(u16, hello[38..40]);
-            if (40 + extensions_size != hello.len) return error.TlsBadLength;
-            var i: usize = 40;
-            var supported_version: u16 = 0;
-            var have_server_pub_key = false;
-            while (i < hello.len) {
-                const et = mem.readIntBig(u16, hello[i..][0..2]);
-                i += 2;
-                const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
-                i += 2;
-                const next_i = i + ext_size;
-                if (next_i > hello.len) return error.TlsBadLength;
-                switch (et) {
-                    @enumToInt(ExtensionType.supported_versions) => {
-                        if (supported_version != 0) return error.TlsIllegalParameter;
-                        supported_version = mem.readIntBig(u16, hello[i..][0..2]);
-                    },
-                    @enumToInt(ExtensionType.key_share) => {
-                        if (have_server_pub_key) return error.TlsIllegalParameter;
-                        const named_group = mem.readIntBig(u16, hello[i..][0..2]);
-                        i += 2;
-                        switch (named_group) {
-                            @enumToInt(NamedGroup.x25519) => {
-                                const key_size = mem.readIntBig(u16, hello[i..][0..2]);
-                                i += 2;
-                                if (key_size != 32) return error.TlsBadLength;
-                                const encrypted_key = hello[i..][0..32].*;
-                                const server_pub_key = try crypto.dh.X25519.scalarmult(
-                                    tls.x25519_priv_key,
-                                    encrypted_key,
-                                );
-                                tls.x25519_server_pub_key = server_pub_key;
-                                have_server_pub_key = true;
-                            },
-                            else => {
-                                std.debug.print("named group: {x}\n", .{named_group});
-                                return error.TlsIllegalParameter;
-                            },
-                        }
+        switch (ct) {
+            .alert => {
+                const level = @intToEnum(AlertLevel, frag[0]);
+                const desc = @intToEnum(AlertDescription, frag[1]);
+                std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
+                return error.TlsAlert;
+            },
+            .handshake => {
+                if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+                    return error.TlsUnexpectedMessage;
+                }
+                const length = mem.readIntBig(u24, frag[1..4]);
+                if (4 + length != frag.len) return error.TlsBadLength;
+                const hello = frag[4..];
+                const legacy_version = mem.readIntBig(u16, hello[0..2]);
+                const random = hello[2..34].*;
+                _ = random;
+                const legacy_session_id_echo_len = hello[34];
+                if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
+                const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
+                const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+                    return error.TlsIllegalParameter;
+                std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
+                const legacy_compression_method = hello[37];
+                _ = legacy_compression_method;
+                const extensions_size = mem.readIntBig(u16, hello[38..40]);
+                if (40 + extensions_size != hello.len) return error.TlsBadLength;
+                var i: usize = 40;
+                var supported_version: u16 = 0;
+                var have_server_pub_key = false;
+                while (i < hello.len) {
+                    const et = mem.readIntBig(u16, hello[i..][0..2]);
+                    i += 2;
+                    const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
+                    i += 2;
+                    const next_i = i + ext_size;
+                    if (next_i > hello.len) return error.TlsBadLength;
+                    switch (et) {
+                        @enumToInt(ExtensionType.supported_versions) => {
+                            if (supported_version != 0) return error.TlsIllegalParameter;
+                            supported_version = mem.readIntBig(u16, hello[i..][0..2]);
+                        },
+                        @enumToInt(ExtensionType.key_share) => {
+                            if (have_server_pub_key) return error.TlsIllegalParameter;
+                            const named_group = mem.readIntBig(u16, hello[i..][0..2]);
+                            i += 2;
+                            switch (named_group) {
+                                @enumToInt(NamedGroup.x25519) => {
+                                    const key_size = mem.readIntBig(u16, hello[i..][0..2]);
+                                    i += 2;
+                                    if (key_size != 32) return error.TlsBadLength;
+                                    const encrypted_key = hello[i..][0..32].*;
+                                    const server_pub_key = try crypto.dh.X25519.scalarmult(
+                                        tls.x25519_priv_key,
+                                        encrypted_key,
+                                    );
+                                    tls.x25519_server_pub_key = server_pub_key;
+                                    have_server_pub_key = true;
+                                },
+                                else => {
+                                    std.debug.print("named group: {x}\n", .{named_group});
+                                    return error.TlsIllegalParameter;
+                                },
+                            }
+                        },
+                        else => {
+                            std.debug.print("unexpected extension: {x}\n", .{et});
+                        },
+                    }
+                    i = next_i;
+                }
+                if (!have_server_pub_key) return error.TlsIllegalParameter;
+                const tls_version = if (supported_version == 0) legacy_version else supported_version;
+                switch (tls_version) {
+                    @enumToInt(ProtocolVersion.tls_1_2) => {
+                        std.debug.print("server wants TLS v1.2\n", .{});
                     },
-                    else => {
-                        std.debug.print("unexpected extension: {x}\n", .{et});
+                    @enumToInt(ProtocolVersion.tls_1_3) => {
+                        std.debug.print("server wants TLS v1.3\n", .{});
                     },
+                    else => return error.TlsIllegalParameter,
                 }
-                i = next_i;
-            }
-            if (!have_server_pub_key) return error.TlsIllegalParameter;
-            const tls_version = if (supported_version == 0) legacy_version else supported_version;
-            switch (tls_version) {
-                @enumToInt(ProtocolVersion.tls_1_2) => {
-                    std.debug.print("server wants TLS v1.2\n", .{});
-                },
-                @enumToInt(ProtocolVersion.tls_1_3) => {
-                    std.debug.print("server wants TLS v1.3\n", .{});
-                },
-                else => return error.TlsIllegalParameter,
-            }
-        } else {
-            std.debug.print("content_type: {s}\n", .{@tagName(ct)});
-            std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
+            },
+            else => return error.TlsUnexpectedMessage,
+        }
+        break :i end;
+    };
+
+    while (true) {
+        const end_hdr = i + 5;
+        if (end_hdr > handshake_buf.len) return error.TlsRecordOverflow;
+        if (end_hdr > len) {
+            len += try stream.readAtLeast(handshake_buf[len..], end_hdr - len);
+            if (end_hdr > len) return error.EndOfStream;
+        }
+        const ct = @intToEnum(ContentType, handshake_buf[i]);
+        i += 1;
+        const legacy_version = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+        i += 2;
+        _ = legacy_version;
+        const record_size = mem.readIntBig(u16, handshake_buf[i..][0..2]);
+        i += 2;
+        const end = i + record_size;
+        if (end > handshake_buf.len) return error.TlsRecordOverflow;
+        if (end > len) {
+            len += try stream.readAtLeast(handshake_buf[len..], end - len);
+            if (end > len) return error.EndOfStream;
+        }
+        switch (ct) {
+            .change_cipher_spec => {
+                if (record_size != 1) return error.TlsUnexpectedMessage;
+                if (handshake_buf[i] != 0x01) return error.TlsUnexpectedMessage;
+            },
+            .application_data => {
+                std.debug.print("TODO: decrypt these {d} bytes\n", .{record_size});
+            },
+            else => {
+                std.debug.print("content type: {s}\n", .{@tagName(ct)});
+                return error.TlsUnexpectedMessage;
+            },
         }
+        i = end;
     }
 
     tls.state = .sent_hello;
lib/std/net.zig
@@ -1680,9 +1680,9 @@ pub const Stream = struct {
     }
 
     /// Returns the number of bytes read, calling the underlying read function
-    /// multiple times until at least the buffer has at least `len` bytes
-    /// filled. If the number read is less than `len` it means the stream
-    /// reached the end. Reaching the end of the stream is not an error
+    /// the minimal number of times until at least the buffer has at least
+    /// `len` bytes filled. If the number read is less than `len` it means the
+    /// stream reached the end. Reaching the end of the stream is not an error
     /// condition.
     pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
         var index: usize = 0;