Commit e2c16d03ab

Andrew Kelley <andrew@ziglang.org>
2022-12-19 02:51:28
std.crypto.tls.Client: support secp256r1 for handshake
1 parent f460c21
Changed files (1)
lib
std
crypto
lib/std/crypto/tls/Client.zig
@@ -46,7 +46,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         // Only possible to happen if the private key is all zeroes.
         error.IdentityElement => return error.InsufficientEntropy,
     };
-    _ = secp256r1_kp;
 
     const extensions_payload =
         tls.extension(.supported_versions, [_]u8{
@@ -70,11 +69,14 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         .rsa_pkcs1_sha1,
         .ecdsa_sha1,
     })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
-        //.secp256r1,
+        .secp256r1,
         .x25519,
     })) ++ tls.extension(
         .key_share,
-        array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++ array(1, x25519_kp.public_key)),
+        array(1, int2(@enumToInt(tls.NamedGroup.x25519)) ++
+            array(1, x25519_kp.public_key) ++
+            int2(@enumToInt(tls.NamedGroup.secp256r1)) ++
+            array(1, secp256r1_kp.public_key.toUncompressedSec1())),
     ) ++
         int2(@enumToInt(tls.ExtensionType.server_name)) ++
         int2(host_len + 5) ++ // byte length of this extension payload
@@ -182,7 +184,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                 i += 2;
                 if (i + extensions_size != frag.len) return error.TlsBadLength;
                 var supported_version: u16 = 0;
-                var opt_x25519_server_pub_key: ?*[32]u8 = null;
+                var shared_key: [32]u8 = undefined;
+                var have_shared_key = false;
                 while (i < frag.len) {
                     const et = mem.readIntBig(u16, frag[i..][0..2]);
                     i += 2;
@@ -196,15 +199,34 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                             supported_version = mem.readIntBig(u16, frag[i..][0..2]);
                         },
                         @enumToInt(tls.ExtensionType.key_share) => {
-                            if (opt_x25519_server_pub_key != null) return error.TlsIllegalParameter;
+                            if (have_shared_key) return error.TlsIllegalParameter;
+                            have_shared_key = true;
                             const named_group = mem.readIntBig(u16, frag[i..][0..2]);
                             i += 2;
+                            const key_size = mem.readIntBig(u16, frag[i..][0..2]);
+                            i += 2;
+
                             switch (named_group) {
                                 @enumToInt(tls.NamedGroup.x25519) => {
-                                    const key_size = mem.readIntBig(u16, frag[i..][0..2]);
-                                    i += 2;
                                     if (key_size != 32) return error.TlsBadLength;
-                                    opt_x25519_server_pub_key = frag[i..][0..32];
+                                    const server_pub_key = frag[i..][0..32];
+
+                                    shared_key = crypto.dh.X25519.scalarmult(
+                                        x25519_kp.secret_key,
+                                        server_pub_key.*,
+                                    ) catch return error.TlsDecryptFailure;
+                                },
+                                @enumToInt(tls.NamedGroup.secp256r1) => {
+                                    const server_pub_key = frag[i..][0..key_size];
+
+                                    const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
+                                    const pk = PublicKey.fromSec1(server_pub_key) catch {
+                                        return error.TlsDecryptFailure;
+                                    };
+                                    const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .Big) catch {
+                                        return error.TlsDecryptFailure;
+                                    };
+                                    shared_key = mul.affineCoordinates().x.toBytes(.Big);
                                 },
                                 else => {
                                     std.debug.print("named group: {x}\n", .{named_group});
@@ -218,8 +240,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                     }
                     i = next_i;
                 }
-                const x25519_server_pub_key = opt_x25519_server_pub_key orelse
-                    return error.TlsIllegalParameter;
+                if (!have_shared_key) return error.TlsIllegalParameter;
                 const tls_version = if (supported_version == 0) legacy_version else supported_version;
                 switch (tls_version) {
                     @enumToInt(tls.ProtocolVersion.tls_1_2) => {
@@ -231,11 +252,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                     else => return error.TlsIllegalParameter,
                 }
 
-                const shared_key = crypto.dh.X25519.scalarmult(
-                    x25519_kp.secret_key,
-                    x25519_server_pub_key.*,
-                ) catch return error.TlsDecryptFailure;
-
                 switch (cipher_suite_tag) {
                     inline .AES_128_GCM_SHA256,
                     .AES_256_GCM_SHA384,