Commit 942b5b468f

Andrew Kelley <andrew@ziglang.org>
2022-12-17 02:06:00
std.crypto.tls: implement the rest of the cipher suites
Also: * Use KeyPair.create() function * Don't bother with CCM
1 parent 93ab8be
Changed files (3)
lib
lib/std/crypto/tls/Client.zig
@@ -23,27 +23,27 @@ partially_read_buffer: [tls.max_ciphertext_record_len]u8,
 partially_read_len: u15,
 eof: bool,
 
-const cipher_suites = blk: {
-    const fields = @typeInfo(CipherSuite).Enum.fields;
-    var result: [(fields.len + 1) * 2]u8 = undefined;
-    mem.writeIntBig(u16, result[0..2], result.len - 2);
-    for (fields) |field, i| {
-        const int = @enumToInt(@field(CipherSuite, field.name));
-        result[(i + 1) * 2] = @truncate(u8, int >> 8);
-        result[(i + 1) * 2 + 1] = @truncate(u8, int);
-    }
-    break :blk result;
-};
+// Measurement taken with 0.11.0-dev.810+c2f5848fe
+// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz:
+// zig run .lib/std/crypto/benchmark.zig -OReleaseFast
+//       aegis-128l:      15382 MiB/s
+//        aegis-256:       9553 MiB/s
+//       aes128-gcm:       3721 MiB/s
+//       aes256-gcm:       3010 MiB/s
+// chacha20Poly1305:        597 MiB/s
+
+const cipher_suites =
+    int2(@enumToInt(tls.CipherSuite.AEGIS_128L_SHA256)) ++
+    int2(@enumToInt(tls.CipherSuite.AEGIS_256_SHA384)) ++
+    int2(@enumToInt(tls.CipherSuite.AES_128_GCM_SHA256)) ++
+    int2(@enumToInt(tls.CipherSuite.AES_256_GCM_SHA384)) ++
+    int2(@enumToInt(tls.CipherSuite.CHACHA20_POLY1305_SHA256));
 
 /// `host` is only borrowed during this function call.
 pub fn init(stream: net.Stream, host: []const u8) !Client {
-    var x25519_priv_key: [32]u8 = undefined;
-    crypto.random.bytes(&x25519_priv_key);
-    const x25519_pub_key = crypto.dh.X25519.recoverPublicKey(x25519_priv_key) catch |err| {
-        switch (err) {
-            // Only possible to happen if the private key is all zeroes.
-            error.IdentityElement => return error.InsufficientEntropy,
-        }
+    const kp = crypto.dh.X25519.KeyPair.create(null) catch |err| switch (err) {
+        // Only possible to happen if the private key is all zeroes.
+        error.IdentityElement => return error.InsufficientEntropy,
     };
 
     // random (u32)
@@ -98,7 +98,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
         0, 36, // byte length of client_shares
         0x00, 0x1D, // NamedGroup.x25519
         0, 32, // byte length of key_exchange
-    } ++ x25519_pub_key ++ [_]u8{
+    } ++ kp.public_key ++ [_]u8{
 
         // Extension: server_name
         0, 0, // ExtensionType.server_name
@@ -120,7 +120,9 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
 
         // ClientHello
         0x03, 0x03, // legacy_version
-    } ++ rand_buf ++ [1]u8{0} ++ cipher_suites ++ [_]u8{
+    } ++ rand_buf ++ [1]u8{0} ++
+        int2(cipher_suites.len) ++ cipher_suites ++
+        [_]u8{
         0x01, 0x00, // legacy_compression_methods
     } ++ extensions_header;
 
@@ -191,9 +193,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                 const legacy_session_id_echo_len = hello[34];
                 if (legacy_session_id_echo_len != 0) return error.TlsIllegalParameter;
                 const cipher_suite_int = mem.readIntBig(u16, hello[35..37]);
-                const cipher_suite_tag = std.meta.intToEnum(CipherSuite, cipher_suite_int) catch
-                    return error.TlsIllegalParameter;
-                std.debug.print("server wants cipher suite {s}\n", .{@tagName(cipher_suite_tag)});
+                const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int);
+                std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag});
                 const legacy_compression_method = hello[37];
                 _ = legacy_compression_method;
                 const extensions_size = mem.readIntBig(u16, hello[38..40]);
@@ -250,13 +251,18 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                 }
 
                 const shared_key = crypto.dh.X25519.scalarmult(
-                    x25519_priv_key,
+                    kp.secret_key,
                     x25519_server_pub_key.*,
                 ) catch return error.TlsDecryptFailure;
 
                 switch (cipher_suite_tag) {
-                    inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |tag| {
-                        const P = std.meta.TagPayload(CipherParams, tag);
+                    inline .AES_128_GCM_SHA256,
+                    .AES_256_GCM_SHA384,
+                    .CHACHA20_POLY1305_SHA256,
+                    .AEGIS_256_SHA384,
+                    .AEGIS_128L_SHA256,
+                    => |tag| {
+                        const P = std.meta.TagPayloadByName(CipherParams, @tagName(tag));
                         cipher_params = @unionInit(CipherParams, @tagName(tag), .{
                             .handshake_secret = undefined,
                             .master_secret = undefined,
@@ -301,14 +307,8 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                         //    std.fmt.fmtSliceHexLower(&p.server_handshake_iv),
                         //});
                     },
-                    .TLS_CHACHA20_POLY1305_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_8_SHA256 => {
-                        @panic("TODO");
+                    else => {
+                        return error.TlsIllegalParameter;
                     },
                 }
             },
@@ -347,7 +347,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
             .application_data => {
                 var cleartext_buf: [8000]u8 = undefined;
                 const cleartext = switch (cipher_params) {
-                    inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
+                    inline else => |*p| c: {
                         const P = @TypeOf(p.*);
                         const ciphertext_len = record_size - P.AEAD.tag_length;
                         const ciphertext = handshake_buf[i..][0..ciphertext_len];
@@ -366,15 +366,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                         p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]);
                         break :c cleartext;
                     },
-                    .TLS_CHACHA20_POLY1305_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_8_SHA256 => {
-                        @panic("TODO");
-                    },
                 };
 
                 const inner_ct = @intToEnum(ContentType, cleartext[cleartext.len - 1]);
@@ -426,7 +417,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                                         0x01,
                                     };
                                     const app_cipher = switch (cipher_params) {
-                                        inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p, tag| c: {
+                                        inline else => |*p, tag| c: {
                                             const P = @TypeOf(p.*);
                                             // TODO verify the server's data
                                             const handshake_hash = p.transcript_hash.finalResult();
@@ -467,15 +458,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
                                                 .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
                                             });
                                         },
-                                        .TLS_CHACHA20_POLY1305_SHA256 => {
-                                            @panic("TODO");
-                                        },
-                                        .TLS_AES_128_CCM_SHA256 => {
-                                            @panic("TODO");
-                                        },
-                                        .TLS_AES_128_CCM_8_SHA256 => {
-                                            @panic("TODO");
-                                        },
                                     };
                                     std.debug.print("remaining bytes: {d}\n", .{len - end});
                                     return .{
@@ -524,7 +506,7 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
     var bytes_i: usize = 0;
     // How many bytes are taken up by overhead per record.
     const overhead_len: usize = switch (c.application_cipher) {
-        inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| l: {
+        inline else => |*p| l: {
             const P = @TypeOf(p.*);
             const V = @Vector(P.AEAD.nonce_length, u8);
             const overhead_len = tls.ciphertext_record_header_len + P.AEAD.tag_length + 1;
@@ -577,15 +559,6 @@ pub fn write(c: *Client, stream: net.Stream, bytes: []const u8) !usize {
                 iovec_end += 1;
             }
         },
-        .TLS_CHACHA20_POLY1305_SHA256 => {
-            @panic("TODO");
-        },
-        .TLS_AES_128_CCM_SHA256 => {
-            @panic("TODO");
-        },
-        .TLS_AES_128_CCM_8_SHA256 => {
-            @panic("TODO");
-        },
     };
 
     // Ideally we would call writev exactly once here, however, we must ensure
@@ -659,7 +632,7 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
             },
             .application_data => {
                 const cleartext_len = switch (c.application_cipher) {
-                    inline .TLS_AES_128_GCM_SHA256, .TLS_AES_256_GCM_SHA384 => |*p| c: {
+                    inline else => |*p| c: {
                         const P = @TypeOf(p.*);
                         const V = @Vector(P.AEAD.nonce_length, u8);
                         const ad = frag[in - 5 ..][0..5];
@@ -682,15 +655,6 @@ pub fn read(c: *Client, stream: net.Stream, buffer: []u8) !usize {
                             return error.TlsBadRecordMac;
                         break :c cleartext.len;
                     },
-                    .TLS_CHACHA20_POLY1305_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_SHA256 => {
-                        @panic("TODO");
-                    },
-                    .TLS_AES_128_CCM_8_SHA256 => {
-                        @panic("TODO");
-                    },
                 };
 
                 const inner_ct = @intToEnum(ContentType, buffer[out + cleartext_len - 1]);
lib/std/crypto/tls.zig
@@ -211,17 +211,20 @@ pub const NamedGroup = enum(u16) {
 };
 
 pub const CipherSuite = enum(u16) {
-    TLS_AES_128_GCM_SHA256 = 0x1301,
-    TLS_AES_256_GCM_SHA384 = 0x1302,
-    TLS_CHACHA20_POLY1305_SHA256 = 0x1303,
-    TLS_AES_128_CCM_SHA256 = 0x1304,
-    TLS_AES_128_CCM_8_SHA256 = 0x1305,
+    AES_128_GCM_SHA256 = 0x1301,
+    AES_256_GCM_SHA384 = 0x1302,
+    CHACHA20_POLY1305_SHA256 = 0x1303,
+    AES_128_CCM_SHA256 = 0x1304,
+    AES_128_CCM_8_SHA256 = 0x1305,
+    AEGIS_256_SHA384 = 0x1306,
+    AEGIS_128L_SHA256 = 0x1307,
+    _,
 };
 
-pub const CipherParams = union(CipherSuite) {
-    TLS_AES_128_GCM_SHA256: struct {
-        pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
-        pub const Hash = crypto.hash.sha2.Sha256;
+pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type {
+    return struct {
+        pub const AEAD = AeadType;
+        pub const Hash = HashType;
         pub const Hmac = crypto.auth.hmac.Hmac(Hash);
         pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
 
@@ -234,33 +237,21 @@ pub const CipherParams = union(CipherSuite) {
         client_handshake_iv: [AEAD.nonce_length]u8,
         server_handshake_iv: [AEAD.nonce_length]u8,
         transcript_hash: Hash,
-    },
-    TLS_AES_256_GCM_SHA384: struct {
-        pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
-        pub const Hash = crypto.hash.sha2.Sha384;
-        pub const Hmac = crypto.auth.hmac.Hmac(Hash);
-        pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+    };
+}
 
-        handshake_secret: [Hkdf.prk_length]u8,
-        master_secret: [Hkdf.prk_length]u8,
-        client_handshake_key: [AEAD.key_length]u8,
-        server_handshake_key: [AEAD.key_length]u8,
-        client_finished_key: [Hmac.key_length]u8,
-        server_finished_key: [Hmac.key_length]u8,
-        client_handshake_iv: [AEAD.nonce_length]u8,
-        server_handshake_iv: [AEAD.nonce_length]u8,
-        transcript_hash: Hash,
-    },
-    TLS_CHACHA20_POLY1305_SHA256: void,
-    TLS_AES_128_CCM_SHA256: void,
-    TLS_AES_128_CCM_8_SHA256: void,
+pub const CipherParams = union(enum) {
+    AES_128_GCM_SHA256: CipherParamsT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
+    AES_256_GCM_SHA384: CipherParamsT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
+    CHACHA20_POLY1305_SHA256: CipherParamsT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
+    AEGIS_256_SHA384: CipherParamsT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
+    AEGIS_128L_SHA256: CipherParamsT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
 };
 
-/// Encryption parameters for application traffic.
-pub const ApplicationCipher = union(CipherSuite) {
-    TLS_AES_128_GCM_SHA256: struct {
-        pub const AEAD = crypto.aead.aes_gcm.Aes128Gcm;
-        pub const Hash = crypto.hash.sha2.Sha256;
+pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type {
+    return struct {
+        pub const AEAD = AeadType;
+        pub const Hash = HashType;
         pub const Hmac = crypto.auth.hmac.Hmac(Hash);
         pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
 
@@ -268,21 +259,16 @@ pub const ApplicationCipher = union(CipherSuite) {
         server_key: [AEAD.key_length]u8,
         client_iv: [AEAD.nonce_length]u8,
         server_iv: [AEAD.nonce_length]u8,
-    },
-    TLS_AES_256_GCM_SHA384: struct {
-        pub const AEAD = crypto.aead.aes_gcm.Aes256Gcm;
-        pub const Hash = crypto.hash.sha2.Sha384;
-        pub const Hmac = crypto.auth.hmac.Hmac(Hash);
-        pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac);
+    };
+}
 
-        client_key: [AEAD.key_length]u8,
-        server_key: [AEAD.key_length]u8,
-        client_iv: [AEAD.nonce_length]u8,
-        server_iv: [AEAD.nonce_length]u8,
-    },
-    TLS_CHACHA20_POLY1305_SHA256: void,
-    TLS_AES_128_CCM_SHA256: void,
-    TLS_AES_128_CCM_8_SHA256: void,
+/// Encryption parameters for application traffic.
+pub const ApplicationCipher = union(enum) {
+    AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256),
+    AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384),
+    CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256),
+    AEGIS_256_SHA384: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha384),
+    AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256),
 };
 
 pub fn hkdfExpandLabel(
lib/std/meta.zig
@@ -810,21 +810,25 @@ test "std.meta.activeTag" {
 
 const TagPayloadType = TagPayload;
 
-///Given a tagged union type, and an enum, return the type of the union
-/// field corresponding to the enum tag.
-pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
+pub fn TagPayloadByName(comptime U: type, comptime tag_name: []const u8) type {
     comptime debug.assert(trait.is(.Union)(U));
 
     const info = @typeInfo(U).Union;
 
     inline for (info.fields) |field_info| {
-        if (comptime mem.eql(u8, field_info.name, @tagName(tag)))
+        if (comptime mem.eql(u8, field_info.name, tag_name))
             return field_info.type;
     }
 
     unreachable;
 }
 
+/// Given a tagged union type, and an enum, return the type of the union field
+/// corresponding to the enum tag.
+pub fn TagPayload(comptime U: type, comptime tag: Tag(U)) type {
+    return TagPayloadByName(U, @tagName(tag));
+}
+
 test "std.meta.TagPayload" {
     const Event = union(enum) {
         Moved: struct {