Commit d2f5d0b199

Andrew Kelley <andrew@ziglang.org>
2022-12-14 04:15:41
std.crypto.Tls: parse the ServerHello handshake
1 parent ba44513
Changed files (3)
lib
std
lib/std/crypto/Tls.zig
@@ -8,6 +8,13 @@ const assert = std.debug.assert;
 state: State = .start,
 x25519_priv_key: [32]u8 = undefined,
 x25519_pub_key: [32]u8 = undefined,
+x25519_server_pub_key: [32]u8 = undefined,
+
+const ProtocolVersion = enum(u16) {
+    tls_1_2 = 0x0303,
+    tls_1_3 = 0x0304,
+    _,
+};
 
 const State = enum {
     /// In this state, all fields are undefined except state.
@@ -186,6 +193,18 @@ const NamedGroup = enum(u16) {
 // * length: u24
 // * data: opaque
 
+// ServerHello:
+// * ProtocolVersion legacy_version = 0x0303;
+// * Random random;
+// * opaque legacy_session_id_echo<0..32>;
+// * CipherSuite cipher_suite;
+// * uint8 legacy_compression_method = 0;
+// * Extension extensions<6..2^16-1>;
+
+// Extension:
+// * ExtensionType extension_type;
+// * opaque extension_data<0..2^16-1>;
+
 const CipherSuite = enum(u16) {
     TLS_AES_128_GCM_SHA256 = 0x1301,
     TLS_AES_256_GCM_SHA384 = 0x1302,
@@ -259,10 +278,10 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
 
         // Extension: key_share
         0, 51, // ExtensionType.key_share
-        0x00, 38, // byte length of this extension payload
-        0x00, 36, // byte length of client_shares
+        0, 38, // byte length of this extension payload
+        0, 36, // byte length of client_shares
         0x00, 0x1D, // NamedGroup.x25519
-        0x00, 32, // byte length of key_exchange
+        0, 32, // byte length of key_exchange
     } ++ tls.x25519_pub_key ++ [_]u8{
 
         // Extension: server_name
@@ -313,21 +332,103 @@ pub fn init(tls: *Tls, stream: net.Stream, host: []const u8) !void {
     try stream.writevAll(&iovecs);
 
     {
-        var buf: [1000]u8 = undefined;
-        const amt = try stream.read(&buf);
-        const resp = buf[0..amt];
-        const ct = @intToEnum(ContentType, resp[0]);
+        var handshake_buf: [4000]u8 = undefined;
+        const plaintext = handshake_buf[0..5];
+        const amt = try stream.readAtLeast(&handshake_buf, plaintext.len);
+        if (amt < plaintext.len) return error.EndOfStream;
+        const ct = @intToEnum(ContentType, plaintext[0]);
+        const frag_len = mem.readIntBig(u16, plaintext[3..][0..2]);
+        const end = plaintext.len + frag_len;
+        if (end > handshake_buf.len) return error.TlsServerHelloTooBig;
+        if (amt < end) {
+            const amt2 = try stream.readAll(handshake_buf[amt..end]);
+            if (amt2 < plaintext.len) return error.EndOfStream;
+        }
+        const frag = handshake_buf[plaintext.len..end];
+
         if (ct == .alert) {
-            //const prot_ver = @bitCast(u16, resp[1..][0..2].*);
-            const len = std.mem.readIntBig(u16, resp[3..][0..2]);
-            const alert = resp[5..][0..len];
-            const level = @intToEnum(AlertLevel, alert[0]);
-            const desc = @intToEnum(AlertDescription, alert[1]);
+            const level = @intToEnum(AlertLevel, frag[0]);
+            const desc = @intToEnum(AlertDescription, frag[1]);
             std.debug.print("alert: {s} {s}\n", .{ @tagName(level), @tagName(desc) });
             std.process.exit(1);
+        } else if (ct == .handshake) {
+            if (frag[0] != @enumToInt(HandshakeType.server_hello)) {
+                return error.TlsUnexpectedMessage;
+            }
+            const length = mem.readIntBig(u24, frag[1..4]);
+            if (4 + length != frag.len) return error.TlsBadLength;
+            const hello = frag[4..];
+            const legacy_version = mem.readIntBig(u16, hello[0..2]);
+            const random = hello[2..34].*;
+            _ = random;
+            const legacy_session_id_echo_len = hello[34];
+            if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
+            const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
+            const cipher_suite = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
+                return error.TlsIllegalParameter;
+            std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite)});
+            const legacy_compression_method = hello[37];
+            _ = legacy_compression_method;
+            const extensions_size = mem.readIntBig(u16, hello[38..40]);
+            if (40 + extensions_size != hello.len) return error.TlsBadLength;
+            var i: usize = 40;
+            var supported_version: u16 = 0;
+            var have_server_pub_key = false;
+            while (i < hello.len) {
+                const et = mem.readIntBig(u16, hello[i..][0..2]);
+                i += 2;
+                const ext_size = mem.readIntBig(u16, hello[i..][0..2]);
+                i += 2;
+                const next_i = i + ext_size;
+                if (next_i > hello.len) return error.TlsBadLength;
+                switch (et) {
+                    @enumToInt(ExtensionType.supported_versions) => {
+                        if (supported_version != 0) return error.TlsIllegalParameter;
+                        supported_version = mem.readIntBig(u16, hello[i..][0..2]);
+                    },
+                    @enumToInt(ExtensionType.key_share) => {
+                        if (have_server_pub_key) return error.TlsIllegalParameter;
+                        const named_group = mem.readIntBig(u16, hello[i..][0..2]);
+                        i += 2;
+                        switch (named_group) {
+                            @enumToInt(NamedGroup.x25519) => {
+                                const key_size = mem.readIntBig(u16, hello[i..][0..2]);
+                                i += 2;
+                                if (key_size != 32) return error.TlsBadLength;
+                                const encrypted_key = hello[i..][0..32].*;
+                                const server_pub_key = try crypto.dh.X25519.scalarmult(
+                                    tls.x25519_priv_key,
+                                    encrypted_key,
+                                );
+                                tls.x25519_server_pub_key = server_pub_key;
+                                have_server_pub_key = true;
+                            },
+                            else => {
+                                std.debug.print("named group: {x}\n", .{named_group});
+                                return error.TlsIllegalParameter;
+                            },
+                        }
+                    },
+                    else => {
+                        std.debug.print("unexpected extension: {x}\n", .{et});
+                    },
+                }
+                i = next_i;
+            }
+            if (!have_server_pub_key) return error.TlsIllegalParameter;
+            const tls_version = if (supported_version == 0) legacy_version else supported_version;
+            switch (tls_version) {
+                @enumToInt(ProtocolVersion.tls_1_2) => {
+                    std.debug.print("server wants TLS v1.2\n", .{});
+                },
+                @enumToInt(ProtocolVersion.tls_1_3) => {
+                    std.debug.print("server wants TLS v1.3\n", .{});
+                },
+                else => return error.TlsIllegalParameter,
+            }
         } else {
             std.debug.print("content_type: {s}\n", .{@tagName(ct)});
-            std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(resp) });
+            std.debug.print("got {d} bytes: {s}\n", .{ amt, std.fmt.fmtSliceHexLower(frag) });
         }
     }
 
lib/std/http/Client.zig
@@ -59,7 +59,7 @@ pub const Request = struct {
 
 pub fn deinit(client: *Client) void {
     assert(client.active_requests == 0);
-    client.headers.denit(client.allocator);
+    client.headers.deinit(client.allocator);
     client.* = undefined;
 }
 
@@ -69,6 +69,7 @@ pub fn request(client: *Client, options: Request.Options) !Request {
         .stream = try net.tcpConnectToHost(client.allocator, options.host, options.port),
         .protocol = options.protocol,
     };
+    client.active_requests += 1;
     errdefer req.deinit();
 
     switch (options.protocol) {
@@ -100,7 +101,6 @@ pub fn request(client: *Client, options: Request.Options) !Request {
     }
     req.headers.appendSliceAssumeCapacity(client.headers.items);
 
-    client.active_requests += 1;
     return req;
 }
 
lib/std/net.zig
@@ -1672,6 +1672,28 @@ pub const Stream = struct {
         }
     }
 
+    /// Returns the number of bytes read. If the number read is smaller than
+    /// `buffer.len`, it means the stream reached the end. Reaching the end of
+    /// a stream is not an error condition.
+    pub fn readAll(s: Stream, buffer: []u8) ReadError!usize {
+        return readAtLeast(s, buffer, buffer.len);
+    }
+
+    /// Returns the number of bytes read, calling the underlying read function
+    /// multiple times until at least the buffer has at least `len` bytes
+    /// filled. If the number read is less than `len` it means the stream
+    /// reached the end. Reaching the end of the stream is not an error
+    /// condition.
+    pub fn readAtLeast(s: Stream, buffer: []u8, len: usize) ReadError!usize {
+        var index: usize = 0;
+        while (index < len) {
+            const amt = try s.read(buffer[index..]);
+            if (amt == 0) break;
+            index += amt;
+        }
+        return index;
+    }
+
     /// TODO in evented I/O mode, this implementation incorrectly uses the event loop's
     /// file system thread instead of non-blocking. It needs to be reworked to properly
     /// use non-blocking I/O.