Commit 1ca3a48b87

Frank Denis <github@pureftpd.org>
2024-03-10 15:30:13
std.crypto: add support for ML-KEM
ML-KEM is the Kyber post-quantum secure key encapsulation mechanism, as being standardized by NIST. Too bad, they decided to rename it; the "Kyber" name was so much better! This implements the current draft (NIST FIPS-203), which is already being deployed even though the specification is not finalized.
1 parent 4ba4f94
Changed files (2)
lib
lib/std/crypto/kyber_d00.zig → lib/std/crypto/ml_kem.zig
@@ -1,14 +1,15 @@
-//! Implementation of the IND-CCA2 post-quantum secure key encapsulation
-//! mechanism (KEM) CRYSTALS-Kyber, as submitted to the third round of the NIST
-//! Post-Quantum Cryptography (v3.02/"draft00"), and selected for standardisation.
+//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM)
+//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft).
 //!
-//! Kyber will likely change before final standardisation.
+//! The schemes are not finalized yet, and are still subject to breaking changes.
 //!
-//! The namespace suffix (currently `_d00`) refers to the version currently
-//! implemented, in accordance with the draft. It may not be updated if new
-//! versions of the draft only include editorial changes.
+//! The Kyber namespace suffix (currently `_d00`) refers to the version currently
+//! implemented, in accordance with the draft.
+//! The ML-KEM namespace suffix (currently `_01`) refers to the NIST FIPS-203 draft
+//! published on August 24, 2023, with the unintentional transposition of  having been reverted.
 //!
-//! The suffix will eventually be removed once Kyber is finalized.
+//! Suffixes may not be updated if new versions of the documents only include editorial changes.
+//! The suffixes will be removed once the schemes are finalized.
 //!
 //! Quoting from the CFRG I-D:
 //!
@@ -108,6 +109,7 @@ const builtin = @import("builtin");
 const testing = std.testing;
 const assert = std.debug.assert;
 const crypto = std.crypto;
+const errors = std.crypto.errors;
 const math = std.math;
 const mem = std.mem;
 const RndGen = std.Random.DefaultPrng;
@@ -128,6 +130,9 @@ const eta2: u8 = 2;
 const Params = struct {
     name: []const u8,
 
+    // NIST ML-KEM variant instead of Kyber as originally submitted.
+    ml_kem: bool = false,
+
     // Width and height of the matrix A.
     k: u8,
 
@@ -143,31 +148,69 @@ const Params = struct {
     dv: u8,
 };
 
-pub const Kyber512 = Kyber(.{
-    .name = "Kyber512",
-    .k = 2,
-    .eta1 = 3,
-    .du = 10,
-    .dv = 4,
-});
-
-pub const Kyber768 = Kyber(.{
-    .name = "Kyber768",
-    .k = 3,
-    .eta1 = 2,
-    .du = 10,
-    .dv = 4,
-});
-
-pub const Kyber1024 = Kyber(.{
-    .name = "Kyber1024",
-    .k = 4,
-    .eta1 = 2,
-    .du = 11,
-    .dv = 5,
-});
-
-const modes = [_]type{ Kyber512, Kyber768, Kyber1024 };
+pub const kyber_d00 = struct {
+    pub const Kyber512 = Kyber(.{
+        .name = "Kyber512",
+        .k = 2,
+        .eta1 = 3,
+        .du = 10,
+        .dv = 4,
+    });
+
+    pub const Kyber768 = Kyber(.{
+        .name = "Kyber768",
+        .k = 3,
+        .eta1 = 2,
+        .du = 10,
+        .dv = 4,
+    });
+
+    pub const Kyber1024 = Kyber(.{
+        .name = "Kyber1024",
+        .k = 4,
+        .eta1 = 2,
+        .du = 11,
+        .dv = 5,
+    });
+};
+
+pub const ml_kem_01 = struct {
+    pub const MLKem512 = Kyber(.{
+        .name = "ML-KEM-512",
+        .ml_kem = true,
+        .k = 2,
+        .eta1 = 3,
+        .du = 10,
+        .dv = 4,
+    });
+
+    pub const MLKem768 = Kyber(.{
+        .name = "ML-KEM-768",
+        .ml_kem = true,
+        .k = 3,
+        .eta1 = 2,
+        .du = 10,
+        .dv = 4,
+    });
+
+    pub const MLKem1024 = Kyber(.{
+        .name = "ML-KEM-1024",
+        .ml_kem = true,
+        .k = 4,
+        .eta1 = 2,
+        .du = 11,
+        .dv = 5,
+    });
+};
+
+const modes = [_]type{
+    kyber_d00.Kyber512,
+    kyber_d00.Kyber768,
+    kyber_d00.Kyber1024,
+    ml_kem_01.MLKem512,
+    ml_kem_01.MLKem768,
+    ml_kem_01.MLKem1024,
+};
 const h_length: usize = 32;
 const inner_seed_length: usize = 32;
 const common_encaps_seed_length: usize = 32;
@@ -211,18 +254,18 @@ fn Kyber(comptime p: Params) type {
             /// If `seed` is `null`, a random seed is used. This is recommended.
             /// If `seed` is set, encapsulation is deterministic.
             pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret {
-                const seed = seed_ orelse seed: {
-                    var random_seed: [encaps_seed_length]u8 = undefined;
-                    crypto.random.bytes(&random_seed);
-                    break :seed random_seed;
-                };
-
                 var m: [inner_plaintext_length]u8 = undefined;
 
-                // m = H(seed)
-                var h = sha3.Sha3_256.init(.{});
-                h.update(&seed);
-                h.final(&m);
+                if (seed_) |seed| {
+                    if (p.ml_kem) {
+                        @memcpy(&m, &seed);
+                    } else {
+                        // m = H(seed)
+                        sha3.Sha3_256.hash(&seed, &m, .{});
+                    }
+                } else {
+                    crypto.random.bytes(&m);
+                }
 
                 // (K', r) = G(m ‖ H(pk))
                 var kr: [inner_plaintext_length + h_length]u8 = undefined;
@@ -235,20 +278,21 @@ fn Kyber(comptime p: Params) type {
                 const ct = pk.pk.encrypt(&m, kr[32..64]);
 
                 // Compute H(c) and put in second slot of kr, which will be (K', H(c)).
-                h = sha3.Sha3_256.init(.{});
-                h.update(&ct);
-                h.final(kr[32..64]);
-
-                // K = KDF(K' ‖ H(c))
-                var kdf = sha3.Shake256.init(.{});
-                kdf.update(&kr);
-                var ss: [shared_length]u8 = undefined;
-                kdf.squeeze(&ss);
-
-                return EncapsulatedSecret{
-                    .shared_secret = ss,
-                    .ciphertext = ct,
-                };
+                sha3.Sha3_256.hash(&ct, kr[32..], .{});
+
+                if (p.ml_kem) {
+                    return EncapsulatedSecret{
+                        .shared_secret = kr[0..shared_length].*, // ML-KEM: K = K'
+                        .ciphertext = ct,
+                    };
+                } else {
+                    var ss: [shared_length]u8 = undefined;
+                    sha3.Shake256.hash(&kr, &ss, .{});
+                    return EncapsulatedSecret{
+                        .shared_secret = ss, // Kyber: K = KDF(K' ‖ H(c))
+                        .ciphertext = ct,
+                    };
+                }
             }
 
             /// Serializes the key into a byte array.
@@ -257,13 +301,10 @@ fn Kyber(comptime p: Params) type {
             }
 
             /// Deserializes the key from a byte array.
-            pub fn fromBytes(buf: *const [bytes_length]u8) !PublicKey {
+            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
                 var ret: PublicKey = undefined;
-                ret.pk = InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
-
-                var h = sha3.Sha3_256.init(.{});
-                h.update(buf);
-                h.final(&ret.hpk);
+                ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
+                sha3.Sha3_256.hash(buf, &ret.hpk, .{});
                 return ret;
             }
         };
@@ -295,19 +336,20 @@ fn Kyber(comptime p: Params) type {
                 const ct2 = sk.pk.encrypt(&m2, kr2[32..64]);
 
                 // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)).
-                var h = sha3.Sha3_256.init(.{});
-                h.update(ct);
-                h.final(kr2[32..64]);
+                sha3.Sha3_256.hash(ct, kr2[32..], .{});
 
                 // Replace K'' by z in the first slot of kr2 if ct ≠ ct'.
                 cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2));
 
-                // K = KDF(K''/z, H(c))
-                var kdf = sha3.Shake256.init(.{});
-                var ss: [shared_length]u8 = undefined;
-                kdf.update(&kr2);
-                kdf.squeeze(&ss);
-                return ss;
+                if (p.ml_kem) {
+                    // ML-KEM: K = K''/z
+                    return kr2[0..shared_length].*;
+                } else {
+                    // Kyber: K = KDF(K''/z ‖ H(c))
+                    var ss: [shared_length]u8 = undefined;
+                    sha3.Shake256.hash(&kr2, &ss, .{});
+                    return ss;
+                }
             }
 
             /// Serializes the key into a byte array.
@@ -316,12 +358,12 @@ fn Kyber(comptime p: Params) type {
             }
 
             /// Deserializes the key from a byte array.
-            pub fn fromBytes(buf: *const [bytes_length]u8) !SecretKey {
+            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
                 var ret: SecretKey = undefined;
                 comptime var s: usize = 0;
                 ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
                 s += InnerSk.bytes_length;
-                ret.pk = InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
+                ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
                 s += InnerPk.bytes_length;
                 ret.hpk = buf[s..][0..h_length].*;
                 s += h_length;
@@ -359,9 +401,7 @@ fn Kyber(comptime p: Params) type {
                 ret.secret_key.z = seed[inner_seed_length..seed_length].*;
 
                 // Compute H(pk)
-                var h = sha3.Sha3_256.init(.{});
-                h.update(&ret.public_key.pk.toBytes());
-                h.final(&ret.secret_key.hpk);
+                sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{});
                 ret.public_key.hpk = ret.secret_key.hpk;
 
                 return ret;
@@ -415,9 +455,19 @@ fn Kyber(comptime p: Params) type {
                 return pk.th.toBytes() ++ pk.rho;
             }
 
-            fn fromBytes(buf: *const [bytes_length]u8) InnerPk {
+            fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
                 var ret: InnerPk = undefined;
-                ret.th = V.fromBytes(buf[0..V.bytes_length]).normalize();
+
+                const th_bytes = buf[0..V.bytes_length];
+                ret.th = V.fromBytes(th_bytes).normalize();
+
+                if (p.ml_kem) {
+                    // Verify that the coefficients used a canonical representation.
+                    if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) {
+                        return error.NonCanonical;
+                    }
+                }
+
                 ret.rho = buf[V.bytes_length..bytes_length].*;
                 ret.aT = M.uniform(ret.rho, true);
                 return ret;
@@ -455,10 +505,7 @@ fn Kyber(comptime p: Params) type {
         // Derives inner PKE keypair from given seed.
         fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void {
             var expanded_seed: [64]u8 = undefined;
-
-            var h = sha3.Sha3_512.init(.{});
-            h.update(&seed);
-            h.final(&expanded_seed);
+            sha3.Sha3_512.hash(&seed, &expanded_seed, .{});
             pk.rho = expanded_seed[0..32].*;
             const sigma = expanded_seed[32..64];
             pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on
@@ -1675,9 +1722,9 @@ const sha2 = crypto.hash.sha2;
 
 test "NIST KAT test" {
     inline for (.{
-        .{ Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" },
-        .{ Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" },
-        .{ Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" },
+        .{ kyber_d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547" },
+        .{ kyber_d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5" },
+        .{ kyber_d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2" },
     }) |modeHash| {
         const mode = modeHash[0];
         var seed: [48]u8 = undefined;
lib/std/crypto.zig
@@ -70,7 +70,8 @@ pub const dh = struct {
 
 /// Key Encapsulation Mechanisms.
 pub const kem = struct {
-    pub const kyber_d00 = @import("crypto/kyber_d00.zig");
+    pub const kyber_d00 = @import("crypto/ml_kem.zig").kyber_d00;
+    pub const ml_kem_01 = @import("crypto/ml_kem.zig").ml_kem_01;
 };
 
 /// Elliptic-curve arithmetic.