master
   1//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM)
   2//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft).
   3//!
//! The `d00` namespace implements Kyber as specified in the CFRG draft.
//! The `nist` namespace implements ML-KEM as specified in the FIPS 203 publication.
   6//!
   7//! Quoting from the CFRG I-D:
   8//!
   9//! Kyber is not a Diffie-Hellman (DH) style non-interactive key
  10//! agreement, but instead, Kyber is a Key Encapsulation Method (KEM).
  11//! In essence, a KEM is a Public-Key Encryption (PKE) scheme where the
  12//! plaintext cannot be specified, but is generated as a random key as
  13//! part of the encryption. A KEM can be transformed into an unrestricted
  14//! PKE using HPKE (RFC9180). On its own, a KEM can be used as a key
  15//! agreement method in TLS.
  16//!
  17//! Kyber is an IND-CCA2 secure KEM. It is constructed by applying a
//! Fujisaki-Okamoto style transformation on InnerPKE, which is the
  19//! underlying IND-CPA secure Public Key Encryption scheme. We cannot
  20//! use InnerPKE directly, as its ciphertexts are malleable.
  21//!
  22//! ```
  23//!                     F.O. transform
  24//!     InnerPKE   ---------------------->   Kyber
  25//!     IND-CPA                              IND-CCA2
  26//! ```
  27//!
  28//! Kyber is a lattice-based scheme.  More precisely, its security is
  29//! based on the learning-with-errors-and-rounding problem in module
  30//! lattices (MLWER).  The underlying polynomial ring R (defined in
  31//! Section 5) is chosen such that multiplication is very fast using the
  32//! number theoretic transform (NTT, see Section 5.1.3).
  33//!
  34//! An InnerPKE private key is a vector _s_ over R of length k which is
  35//! _small_ in a particular way.  Here k is a security parameter akin to
  36//! the size of a prime modulus.  For Kyber512, which targets AES-128's
  37//! security level, the value of k is 2.
  38//!
  39//! The public key consists of two values:
  40//!
  41//! * _A_ a uniformly sampled k by k matrix over R _and_
  42//!
  43//! * _t = A s + e_, where e is a suitably small masking vector.
  44//!
  45//! Distinguishing between such A s + e and a uniformly sampled t is the
  46//! module learning-with-errors (MLWE) problem.  If that is hard, then it
  47//! is also hard to recover the private key from the public key as that
  48//! would allow you to distinguish between those two.
  49//!
  50//! To save space in the public key, A is recomputed deterministically
  51//! from a seed _rho_.
  52//!
  53//! A ciphertext for a message m under this public key is a pair (c_1,
  54//! c_2) computed roughly as follows:
  55//!
  56//! c_1 = Compress(A^T r + e_1, d_u)
  57//! c_2 = Compress(t^T r + e_2 + Decompress(m, 1), d_v)
  58//!
  59//! where
  60//!
  61//! * e_1, e_2 and r are small blinds;
  62//!
  63//! * Compress(-, d) removes some information, leaving d bits per
  64//!   coefficient and Decompress is such that Compress after Decompress
  65//!   does nothing and
  66//!
  67//! * d_u, d_v are scheme parameters.
  68//!
  69//! Distinguishing such a ciphertext and uniformly sampled (c_1, c_2) is
  70//! an example of the full MLWER problem, see section 4.4 of [KyberV302].
  71//!
  72//! To decrypt the ciphertext, one computes
  73//!
  74//! m = Compress(Decompress(c_2, d_v) - s^T Decompress(c_1, d_u), 1).
  75//!
//! It is not straightforward to see that this formula is correct.  In
//! fact, there is a negligible but non-zero probability that a ciphertext
//! does not decrypt correctly; it is given by the DFP (decryption failure
//! probability) column in Table 4.  This failure probability can be
//! computed by a careful automated analysis of the probabilities involved,
//! see kyber_failure.py of [SecEst].
  81//!
  82//! [KyberV302](https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf)
  83//! [I-D](https://github.com/bwesterb/draft-schwabe-cfrg-kyber)
  84//! [SecEst](https://github.com/pq-crystals/security-estimates)
  85
  86// TODO
  87//
// - The bottleneck in Kyber is the various hash/XOF calls:
  89//    - Optimize Zig's keccak implementation.
  90//    - Use SIMD to compute keccak in parallel.
  91// - Can we track bounds of coefficients using comptime types without
  92//   duplicating code?
  93// - Would be neater to have tests closer to the thing under test.
  94// - When generating a keypair, we have a copy of the inner public key with
  95//   its large matrix A in both the public key and the private key. In Go we
  96//   can just have a pointer in the private key to the public key, but
  97//   how do we do this elegantly in Zig?
  98
  99const std = @import("std");
 100const builtin = @import("builtin");
 101
 102const testing = std.testing;
 103const assert = std.debug.assert;
 104const crypto = std.crypto;
 105const errors = std.crypto.errors;
 106const math = std.math;
 107const mem = std.mem;
 108const sha3 = crypto.hash.sha3;
 109
 110const RndGen = std.Random.DefaultPrng;
 111
 112// Q is the modulus q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1
 113const Q: i16 = 3329;
 114
 115// Montgomery R = 2^16 mod Q (for Montgomery multiplication)
 116const R: i32 = 1 << 16;
 117
 118// N is the degree of polynomials (polynomial ring dimension)
 119const N: usize = 256;
 120
 121// eta2 is the size of "small" vectors used in encryption blinds
 122const eta2: u8 = 2;
 123
 124const Params = struct {
 125    name: []const u8,
 126
 127    // NIST ML-KEM variant instead of Kyber as originally submitted.
 128    ml_kem: bool = false,
 129
 130    // Width and height of the matrix A.
 131    k: u8,
 132
 133    // Size of "small" vectors used in private key and encryption blinds.
 134    eta1: u8,
 135
 136    // How many bits to retain of u, the private-key independent part
 137    // of the ciphertext.
 138    du: u8,
 139
 140    // How many bits to retain of v, the private-key dependent part
 141    // of the ciphertext.
 142    dv: u8,
 143};
 144
 145pub const d00 = struct {
 146    pub const Kyber512 = Kyber(.{
 147        .name = "Kyber512",
 148        .k = 2,
 149        .eta1 = 3,
 150        .du = 10,
 151        .dv = 4,
 152    });
 153
 154    pub const Kyber768 = Kyber(.{
 155        .name = "Kyber768",
 156        .k = 3,
 157        .eta1 = 2,
 158        .du = 10,
 159        .dv = 4,
 160    });
 161
 162    pub const Kyber1024 = Kyber(.{
 163        .name = "Kyber1024",
 164        .k = 4,
 165        .eta1 = 2,
 166        .du = 11,
 167        .dv = 5,
 168    });
 169};
 170
 171pub const nist = struct {
 172    pub const MLKem512 = Kyber(.{
 173        .name = "ML-KEM-512",
 174        .ml_kem = true,
 175        .k = 2,
 176        .eta1 = 3,
 177        .du = 10,
 178        .dv = 4,
 179    });
 180
 181    pub const MLKem768 = Kyber(.{
 182        .name = "ML-KEM-768",
 183        .ml_kem = true,
 184        .k = 3,
 185        .eta1 = 2,
 186        .du = 10,
 187        .dv = 4,
 188    });
 189
 190    pub const MLKem1024 = Kyber(.{
 191        .name = "ML-KEM-1024",
 192        .ml_kem = true,
 193        .k = 4,
 194        .eta1 = 2,
 195        .du = 11,
 196        .dv = 5,
 197    });
 198};
 199
 200const modes = [_]type{
 201    d00.Kyber512,
 202    d00.Kyber768,
 203    d00.Kyber1024,
 204    nist.MLKem512,
 205    nist.MLKem768,
 206    nist.MLKem1024,
 207};
 208const h_length: usize = 32;
 209const inner_seed_length: usize = 32;
 210const common_encaps_seed_length: usize = 32;
 211const common_shared_key_size: usize = 32;
 212
 213fn Kyber(comptime p: Params) type {
 214    return struct {
 215        // Size of a ciphertext, in bytes.
 216        pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
 217
 218        const Self = @This();
 219        const V = PolyVec(p.k);
 220        const M = Mat(p.k);
 221
 222        /// Length (in bytes) of a shared secret.
 223        pub const shared_length = common_shared_key_size;
 224        /// Length (in bytes) of a seed for deterministic encapsulation.
 225        pub const encaps_seed_length = common_encaps_seed_length;
 226        /// Length (in bytes) of a seed for key generation.
 227        pub const seed_length: usize = inner_seed_length + shared_length;
 228        /// Algorithm name.
 229        pub const name = p.name;
 230
 231        /// A shared secret, and an encapsulated (encrypted) representation of it.
 232        pub const EncapsulatedSecret = struct {
 233            shared_secret: [shared_length]u8,
 234            ciphertext: [ciphertext_length]u8,
 235        };
 236
 237        /// A Kyber public key.
 238        pub const PublicKey = struct {
 239            pk: InnerPk,
 240
 241            // Cached
 242            hpk: [h_length]u8, // H(pk)
 243
 244            /// Size of a serialized representation of the key, in bytes.
 245            pub const encoded_length = InnerPk.encoded_length;
 246
 247            /// Generates a shared secret, and encapsulates it for the public key.
 248            /// If `seed` is `null`, a random seed is used. This is recommended.
 249            /// If `seed` is set, encapsulation is deterministic.
 250            pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret {
 251                var m: [inner_plaintext_length]u8 = undefined;
 252
 253                if (seed_) |seed| {
 254                    if (p.ml_kem) {
 255                        @memcpy(&m, &seed);
 256                    } else {
 257                        // m = H(seed)
 258                        sha3.Sha3_256.hash(&seed, &m, .{});
 259                    }
 260                } else {
 261                    crypto.random.bytes(&m);
 262                }
 263
 264                // (K', r) = G(m ‖ H(pk))
 265                var kr: [inner_plaintext_length + h_length]u8 = undefined;
 266                var g = sha3.Sha3_512.init(.{});
 267                g.update(&m);
 268                g.update(&pk.hpk);
 269                g.final(&kr);
 270
 271                // c = innerEncrypt(pk, m, r)
 272                const ct = pk.pk.encrypt(&m, kr[32..64]);
 273
 274                if (p.ml_kem) {
 275                    return EncapsulatedSecret{
 276                        .shared_secret = kr[0..shared_length].*, // ML-KEM: K = K'
 277                        .ciphertext = ct,
 278                    };
 279                } else {
 280                    // Compute H(c) and put in second slot of kr, which will be (K', H(c)).
 281                    sha3.Sha3_256.hash(&ct, kr[32..], .{});
 282
 283                    var ss: [shared_length]u8 = undefined;
 284                    sha3.Shake256.hash(&kr, &ss, .{});
 285                    return EncapsulatedSecret{
 286                        .shared_secret = ss, // Kyber: K = KDF(K' ‖ H(c))
 287                        .ciphertext = ct,
 288                    };
 289                }
 290            }
 291
 292            /// Serializes the key into a byte array.
 293            pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
 294                return pk.pk.toBytes();
 295            }
 296
 297            /// Deserializes the key from a byte array.
 298            pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
 299                var ret: PublicKey = undefined;
 300                ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
 301                sha3.Sha3_256.hash(buf, &ret.hpk, .{});
 302                return ret;
 303            }
 304        };
 305
 306        /// A Kyber secret key.
 307        pub const SecretKey = struct {
 308            sk: InnerSk,
 309            pk: InnerPk,
 310            hpk: [h_length]u8, // H(pk)
 311            z: [shared_length]u8,
 312
 313            /// Size of a serialized representation of the key, in bytes.
 314            pub const encoded_length: usize =
 315                InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
 316
 317            /// Decapsulates the shared secret within ct using the private key.
 318            pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
 319                // m' = innerDec(ct)
 320                const m2 = sk.sk.decrypt(ct);
 321
 322                // (K'', r') = G(m' ‖ H(pk))
 323                var kr2: [64]u8 = undefined;
 324                var g = sha3.Sha3_512.init(.{});
 325                g.update(&m2);
 326                g.update(&sk.hpk);
 327                g.final(&kr2);
 328
 329                // ct' = innerEnc(pk, m', r')
 330                const ct2 = sk.pk.encrypt(&m2, kr2[32..64]);
 331
 332                // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)).
 333                sha3.Sha3_256.hash(ct, kr2[32..], .{});
 334
 335                // Replace K'' by z in the first slot of kr2 if ct ≠ ct'.
 336                cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2));
 337
 338                if (p.ml_kem) {
 339                    // ML-KEM: K = K''/z
 340                    return kr2[0..shared_length].*;
 341                } else {
 342                    // Kyber: K = KDF(K''/z ‖ H(c))
 343                    var ss: [shared_length]u8 = undefined;
 344                    sha3.Shake256.hash(&kr2, &ss, .{});
 345                    return ss;
 346                }
 347            }
 348
 349            /// Serializes the key into a byte array.
 350            pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
 351                return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
 352            }
 353
 354            /// Deserializes the key from a byte array.
 355            pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
 356                var ret: SecretKey = undefined;
 357                comptime var s: usize = 0;
 358                ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
 359                s += InnerSk.encoded_length;
 360                ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
 361                s += InnerPk.encoded_length;
 362                ret.hpk = buf[s..][0..h_length].*;
 363                s += h_length;
 364                ret.z = buf[s..][0..shared_length].*;
 365                return ret;
 366            }
 367        };
 368
 369        /// A Kyber key pair.
 370        pub const KeyPair = struct {
 371            secret_key: SecretKey,
 372            public_key: PublicKey,
 373
 374            /// Deterministically derive a key pair from a cryptograpically secure secret seed.
 375            ///
 376            /// Except in tests, applications should generally call `generate()` instead of this function.
 377            pub fn generateDeterministic(seed: [seed_length]u8) !KeyPair {
 378                var ret: KeyPair = undefined;
 379
 380                // Generate inner key
 381                innerKeyFromSeed(
 382                    seed[0..inner_seed_length].*,
 383                    &ret.public_key.pk,
 384                    &ret.secret_key.sk,
 385                );
 386                ret.secret_key.pk = ret.public_key.pk;
 387
 388                // Copy over z from seed.
 389                ret.secret_key.z = seed[inner_seed_length..seed_length].*;
 390
 391                // Compute H(pk)
 392                sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{});
 393                ret.public_key.hpk = ret.secret_key.hpk;
 394
 395                return ret;
 396            }
 397
 398            /// Generate a new, random key pair.
 399            pub fn generate() KeyPair {
 400                var random_seed: [seed_length]u8 = undefined;
 401                while (true) {
 402                    crypto.random.bytes(&random_seed);
 403                    return generateDeterministic(random_seed) catch {
 404                        @branchHint(.unlikely);
 405                        continue;
 406                    };
 407                }
 408            }
 409        };
 410
 411        // Size of plaintexts of the in
 412        const inner_plaintext_length: usize = Poly.compressedSize(1);
 413
 414        const InnerPk = struct {
 415            rho: [32]u8, // ρ, the seed for the matrix A
 416            th: V, // NTT(t), normalized
 417
 418            // Cached values
 419            aT: M,
 420
 421            const encoded_length = V.encoded_length + 32;
 422
 423            fn encrypt(
 424                pk: InnerPk,
 425                pt: *const [inner_plaintext_length]u8,
 426                seed: *const [32]u8,
 427            ) [ciphertext_length]u8 {
 428                // Sample r, e₁ and e₂ appropriately
 429                const rh = V.noise(p.eta1, 0, seed).ntt().barrettReduce();
 430                const e1 = V.noise(eta2, p.k, seed);
 431                const e2 = Poly.noise(eta2, 2 * p.k, seed);
 432
 433                // Next we compute u = Aᵀ r + e₁.  First Aᵀ.
 434                var u: V = undefined;
 435                for (0..p.k) |i| {
 436                    // Note that coefficients of r are bounded by q and those of Aᵀ
 437                    // are bounded by 4.5q and so their product is bounded by 2¹⁵q
 438                    // as required for multiplication.
 439                    u.ps[i] = pk.aT.rows[i].dotHat(rh);
 440                }
 441
 442                // Aᵀ and r were not in Montgomery form, so the Montgomery
 443                // multiplications in the inner product added a factor R⁻¹ which
 444                // the InvNTT cancels out.
 445                u = u.barrettReduce().invNTT().add(e1).normalize();
 446
 447                // Next, compute v = <t, r> + e₂ + Decompress_q(m, 1)
 448                const v = pk.th.dotHat(rh).barrettReduce().invNTT()
 449                    .add(Poly.decompress(1, pt)).add(e2).normalize();
 450
 451                return u.compress(p.du) ++ v.compress(p.dv);
 452            }
 453
 454            fn toBytes(pk: InnerPk) [encoded_length]u8 {
 455                return pk.th.toBytes() ++ pk.rho;
 456            }
 457
 458            fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
 459                var ret: InnerPk = undefined;
 460
 461                const th_bytes = buf[0..V.encoded_length];
 462                ret.th = V.fromBytes(th_bytes).normalize();
 463
 464                if (p.ml_kem) {
 465                    // Verify that the coefficients used a canonical representation.
 466                    if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) {
 467                        return error.NonCanonical;
 468                    }
 469                }
 470
 471                ret.rho = buf[V.encoded_length..encoded_length].*;
 472                ret.aT = M.uniform(ret.rho, true);
 473                return ret;
 474            }
 475        };
 476
 477        // Private key of the inner PKE
 478        const InnerSk = struct {
 479            sh: V, // NTT(s), normalized
 480            const encoded_length = V.encoded_length;
 481
 482            fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
 483                const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
 484                const v = Poly.decompress(
 485                    p.dv,
 486                    ct[comptime V.compressedSize(p.du)..ciphertext_length],
 487                );
 488
 489                // Compute m = v - <s, u>
 490                return v.sub(sk.sh.dotHat(u.ntt()).barrettReduce().invNTT())
 491                    .normalize().compress(1);
 492            }
 493
 494            fn toBytes(sk: InnerSk) [encoded_length]u8 {
 495                return sk.sh.toBytes();
 496            }
 497
 498            fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
 499                var ret: InnerSk = undefined;
 500                ret.sh = V.fromBytes(buf).normalize();
 501                return ret;
 502            }
 503        };
 504
 505        // Derives inner PKE keypair from given seed.
 506        fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void {
 507            var expanded_seed: [64]u8 = undefined;
 508            var h = sha3.Sha3_512.init(.{});
 509            h.update(&seed);
 510            if (p.ml_kem) h.update(&[1]u8{p.k});
 511            h.final(&expanded_seed);
 512            pk.rho = expanded_seed[0..32].*;
 513            const sigma = expanded_seed[32..64];
 514            pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on
 515
 516            // Sample secret vector s.
 517            sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
 518
 519            const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
 520            var th: V = undefined;
 521
 522            // Next, we compute t = A s + e.
 523            for (0..p.k) |i| {
 524                // Note that coefficients of s are bounded by q and those of A
 525                // are bounded by 4.5q and so their product is bounded by 2¹⁵q
 526                // as required for multiplication.
 527                // A and s were not in Montgomery form, so the Montgomery
 528                // multiplications in the inner product added a factor R⁻¹ which
 529                // we'll cancel out with toMont().  This will also ensure the
 530                // coefficients of th are bounded in absolute value by q.
 531                th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
 532            }
 533
 534            pk.th = th.add(eh).normalize(); // bounded by 8q
 535            pk.aT = pk.aT.transpose();
 536        }
 537    };
 538}
 539
 540// R mod q
 541const r_mod_q: i32 = @rem(@as(i32, R), Q);
 542
 543// R² mod q
 544const r2_mod_q: i32 = @rem(r_mod_q * r_mod_q, Q);
 545
 546// ζ is the degree 256 primitive root of unity used for the NTT.
 547const zeta: i16 = 17;
 548
 549// (128)⁻¹ R². Used in inverse NTT.
 550const r2_over_128: i32 = @mod(invertMod(128, Q) * r2_mod_q, Q);
 551
 552// zetas lists precomputed powers of the primitive root of unity in
 553// Montgomery representation used for the NTT:
 554//
 555//  zetas[i] = ζᵇʳᵛ⁽ⁱ⁾ R mod q
 556//
 557// where ζ = 17, brv(i) is the bitreversal of a 7-bit number and R=2¹⁶ mod q.
 558const zetas = computeZetas();
 559
 560// invNTTReductions keeps track of which coefficients to apply Barrett
 561// reduction to in Poly.invNTT().
 562//
 563// Generated lazily: once a butterfly is computed which is about to
 564// overflow the i16, the largest coefficient is reduced.  If that is
 565// not enough, the other coefficient is reduced as well.
 566//
 567// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
 568const inv_ntt_reductions = [_]i16{
 569    -1, // after layer 1
 570    -1, // after layer 2
 571    16,
 572    17,
 573    48,
 574    49,
 575    80,
 576    81,
 577    112,
 578    113,
 579    144,
 580    145,
 581    176,
 582    177,
 583    208,
 584    209,
 585    240, 241, -1, // after layer 3
 586    0,   1,   32,
 587    33,  34,  35,
 588    64,  65,  96,
 589    97,  98,  99,
 590    128, 129,
 591    160, 161, 162, 163, 192, 193, 224, 225, 226, 227, -1, // after layer 4
 592    2,   3,   66,  67,  68,  69,  70,  71,  130, 131, 194,
 593    195, 196, 197,
 594    198, 199, -1, // after layer 5
 595    4,   5,   6,
 596    7,   132, 133,
 597    134, 135, 136,
 598    137, 138, 139,
 599    140, 141,
 600    142, 143, -1, // after layer 6
 601    -1, //  after layer 7
 602};
 603
 604test "invNTTReductions bounds" {
 605    // Checks whether the reductions proposed by invNTTReductions
 606    // don't overflow during invNTT().
 607    var xs = [_]i32{1} ** 256; // start at |x| ≤ q
 608
 609    var r: usize = 0;
 610    var layer: math.Log2Int(usize) = 1;
 611    while (layer < 8) : (layer += 1) {
 612        const w = @as(usize, 1) << layer;
 613        var i: usize = 0;
 614
 615        while (i + w < 256) {
 616            xs[i] = xs[i] + xs[i + w];
 617            try testing.expect(xs[i] <= 9); // we can't exceed 9q
 618            xs[i + w] = 1;
 619            i += 1;
 620            if (@mod(i, w) == 0) {
 621                i += w;
 622            }
 623        }
 624
 625        while (true) {
 626            const j = inv_ntt_reductions[r];
 627            r += 1;
 628            if (j < 0) {
 629                break;
 630            }
 631            xs[@as(usize, @intCast(j))] = 1;
 632        }
 633    }
 634}
 635
 636fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
 637    const r = extendedEuclidean(@TypeOf(a), a, p);
 638    assert(r.gcd == 1);
 639    return r.x;
 640}
 641
 642// Reduce mod q for testing.
 643fn modQ32(x: i32) i16 {
 644    var y = @as(i16, @intCast(@rem(x, @as(i32, Q))));
 645    if (y < 0) {
 646        y += Q;
 647    }
 648    return y;
 649}
 650
 651// Given -2¹⁵ q ≤ x < 2¹⁵ q, returns -q < y < q with x 2⁻¹⁶ = y (mod q).
 652fn montReduce(x: i32) i16 {
 653    const qInv = comptime invertMod(@as(i32, Q), R);
 654    // This is Montgomery reduction with R=2¹⁶.
 655    //
 656    // Note gcd(2¹⁶, q) = 1 as q is prime.  Write q' := 62209 = q⁻¹ mod R.
 657    // First we compute
 658    //
 659    // m := ((x mod R) q') mod R
 660    //         = x q' mod R
 661    //    = int16(x q')
 662    //    = int16(int32(x) * int32(q'))
 663    //
 664    // Note that x q' might be as big as 2³² and could overflow the int32
 665    // multiplication in the last line.  However for any int32s a and b,
 666    // we have int32(int64(a)*int64(b)) = int32(a*b) and so the result is ok.
 667    const m: i16 = @truncate(@as(i32, @truncate(x *% qInv)));
 668
 669    // Note that x - m q is divisible by R; indeed modulo R we have
 670    //
 671    //  x - m q ≡ x - x q' q ≡ x - x q⁻¹ q ≡ x - x = 0.
 672    //
 673    // We return y := (x - m q) / R.  Note that y is indeed correct as
 674    // modulo q we have
 675    //
 676    //  y ≡ x R⁻¹ - m q R⁻¹ = x R⁻¹
 677    //
 678    // and as both 2¹⁵ q ≤ m q, x < 2¹⁵ q, we have
 679    // 2¹⁶ q ≤ x - m q < 2¹⁶ and so q ≤ (x - m q) / R < q as desired.
 680    const yR = x - @as(i32, m) * @as(i32, Q);
 681    return @bitCast(@as(u16, @truncate(@as(u32, @bitCast(yR)) >> 16)));
 682}
 683
 684test "Test montReduce" {
 685    var rnd = RndGen.init(0);
 686    for (0..1000) |_| {
 687        const bound = comptime @as(i32, Q) * (1 << 15);
 688        const x = rnd.random().intRangeLessThan(i32, -bound, bound);
 689        const y = montReduce(x);
 690        try testing.expect(-Q < y and y < Q);
 691        try testing.expectEqual(modQ32(x), modQ32(@as(i32, y) * R));
 692    }
 693}
 694
 695// Given any x, return x R mod q where R=2¹⁶.
 696fn feToMont(x: i16) i16 {
 697    // Note |1353 x| ≤ 1353 2¹⁵ ≤ 13318 q ≤ 2¹⁵ q and so we're within
 698    // the bounds of montReduce.
 699    return montReduce(@as(i32, x) * r2_mod_q);
 700}
 701
 702test "Test feToMont" {
 703    var x: i32 = -(1 << 15);
 704    while (x < 1 << 15) : (x += 1) {
 705        const y = feToMont(@as(i16, @intCast(x)));
 706        try testing.expectEqual(modQ32(@as(i32, y)), modQ32(x * r_mod_q));
 707    }
 708}
 709
 710// Given any x, compute 0 ≤ y ≤ q with x = y (mod q).
 711//
 712// Beware: we might have feBarrettReduce(x) = q ≠ 0 for some x.  In fact,
 713// this happens if and only if x = -nq for some positive integer n.
 714fn feBarrettReduce(x: i16) i16 {
 715    // This is standard Barrett reduction.
 716    //
 717    // For any x we have x mod q = x - ⌊x/q⌋ q.  We will use 20159/2²⁶ as
 718    // an approximation of 1/q. Note that  0 ≤ 20159/2²⁶ - 1/q ≤ 0.135/2²⁶
 719    // and so | x 20156/2²⁶ - x/q | ≤ 2⁻¹⁰ for |x| ≤ 2¹⁶.  For all x
 720    // not a multiple of q, the number x/q is further than 1/q from any integer
 721    // and so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋.  If x is a multiple of q and x is positive,
 722    // then x 20156/2²⁶ is larger than x/q so ⌊x 20156/2²⁶⌋ = ⌊x/q⌋ as well.
 723    // Finally, if x is negative multiple of q, then ⌊x 20156/2²⁶⌋ = ⌊x/q⌋-1.
 724    // Thus
 725    //                        [ q        if x=-nq for pos. integer n
 726    //  x - ⌊x 20156/2²⁶⌋ q = [
 727    //                        [ x mod q  otherwise
 728    //
 729    // To actually compute this, note that
 730    //
 731    //  ⌊x 20156/2²⁶⌋ = (20159 x) >> 26.
 732    return x -% @as(i16, @intCast((@as(i32, x) * 20159) >> 26)) *% Q;
 733}
 734
 735test "Test Barrett reduction" {
 736    var x: i32 = -(1 << 15);
 737    while (x < 1 << 15) : (x += 1) {
 738        var y1 = feBarrettReduce(@as(i16, @intCast(x)));
 739        const y2 = @mod(@as(i16, @intCast(x)), Q);
 740        if (x < 0 and @rem(-x, Q) == 0) {
 741            y1 -= Q;
 742        }
 743        try testing.expectEqual(y1, y2);
 744    }
 745}
 746
 747// Returns x if x < q and x - q otherwise.  Assumes x ≥ -29439.
 748fn csubq(x: i16) i16 {
 749    var r = x;
 750    r -= Q;
 751    r += (r >> 15) & Q;
 752    return r;
 753}
 754
 755test "Test csubq" {
 756    var x: i32 = -29439;
 757    while (x < 1 << 15) : (x += 1) {
 758        const y1 = csubq(@as(i16, @intCast(x)));
 759        var y2 = @as(i16, @intCast(x));
 760        if (@as(i16, @intCast(x)) >= Q) {
 761            y2 -= Q;
 762        }
 763        try testing.expectEqual(y1, y2);
 764    }
 765}
 766
 767// Computes zetas table used by ntt and invNTT.
 768fn computeZetas() [128]i16 {
 769    @setEvalBranchQuota(10000);
 770    var ret: [128]i16 = undefined;
 771    for (&ret, 0..) |*r, i| {
 772        const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
 773        r.* = csubq(feBarrettReduce(feToMont(t)));
 774    }
 775    return ret;
 776}
 777
 778// An element of our base ring R which are polynomials over ℤ_q
 779// modulo the equation Xᴺ = -1, where q=3329 and N=256.
 780//
 781// This type is also used to store NTT-transformed polynomials,
 782// see Poly.NTT().
 783//
 784// Coefficients aren't always reduced.  See Normalize().
 785const Poly = struct {
 786    cs: [N]i16,
 787
 788    const encoded_length = N / 2 * 3;
 789    const zero: Poly = .{ .cs = .{0} ** N };
 790
 791    // Add two polynomials (coefficients not normalized)
 792    fn add(a: Poly, b: Poly) Poly {
 793        var ret: Poly = undefined;
 794        for (0..N) |i| {
 795            ret.cs[i] = a.cs[i] + b.cs[i];
 796        }
 797        return ret;
 798    }
 799
 800    // Subtract two polynomials (coefficients not normalized)
 801    fn sub(a: Poly, b: Poly) Poly {
 802        var ret: Poly = undefined;
 803        for (0..N) |i| {
 804            ret.cs[i] = a.cs[i] - b.cs[i];
 805        }
 806        return ret;
 807    }
 808
 809    // Executes a forward "NTT" on p.
 810    //
 811    // Assumes the coefficients are in absolute value ≤q.  The resulting
 812    // coefficients are in absolute value ≤7q.  If the input is in Montgomery
 813    // form, then the result is in Montgomery form and so (by linearity of the NTT)
 814    // if the input is in regular form, then the result is also in regular form.
 815    fn ntt(a: Poly) Poly {
 816        // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512
 817        // does not divide into q-1) and so we cannot do a regular NTT.  ℤ_q
 818        // does have a primitive 256ᵗʰ root of unity, the smallest of which
 819        // is ζ := 17.
 820        //
 821        // Recall that our base ring R := ℤ_q[x] / (x²⁵⁶ + 1).  The polynomial
 822        // x²⁵⁶+1 will not split completely (as its roots would be 512ᵗʰ roots
 823        // of unity.)  However, it does split almost (using ζ¹²⁸ = -1):
 824        //
 825        // x²⁵⁶ + 1 = (x²)¹²⁸ - ζ¹²⁸
 826        //          = ((x²)⁶⁴ - ζ⁶⁴)((x²)⁶⁴ + ζ⁶⁴)
 827        //          = ((x²)³² - ζ³²)((x²)³² + ζ³²)((x²)³² - ζ⁹⁶)((x²)³² + ζ⁹⁶)
 828        //          ⋮
 829        //          = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷)
 830        //
 831        // Note that the powers of ζ that appear (from the second line down) are
 832        // in binary
 833        //
 834        // 0100000 1100000
 835        // 0010000 1010000 0110000 1110000
 836        // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000
 837        //         …
 838        //
 839        // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit
 840        // bitreversal of x.  These powers of ζ are given by the Zetas array.
 841        //
 842        // The polynomials x² ± ζⁱ are irreducible and coprime, hence by
 843        // the Chinese Remainder Theorem we know
 844        //
 845        //  ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x  ℤ_q[x]/(x²+ζ¹²⁷)
 846        //
 847        // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ )
 848        // is an isomorphism, which is the "NTT".  It can be efficiently computed by
 849        //
 850        //
 851        //  a ↦ ( a mod (x²)⁶⁴ - ζ⁶⁴, a mod (x²)⁶⁴ + ζ⁶⁴ )
 852        //    ↦ ( a mod (x²)³² - ζ³², a mod (x²)³² + ζ³²,
 853        //        a mod (x²)⁹⁶ - ζ⁹⁶, a mod (x²)⁹⁶ + ζ⁹⁶ )
 854        //
 855        //      et cetera
 856        // If N was 8 then this can be pictured in the following diagram:
 857        //
 858        //  https://cnx.org/resources/17ee4dfe517a6adda05377b25a00bf6e6c93c334/File0026.png
 859        //
 860        // Each cross is a Cooley-Tukey butterfly: it's the map
 861        //
 862        //  (a, b) ↦ (a + ζb, a - ζb)
 863        //
 864        // for the appropriate power ζ for that column and row group.
 865        var p = a;
 866        var k: usize = 0; // index into zetas
 867
 868        var l = N >> 1;
 869        while (l > 1) : (l >>= 1) {
 870            // On the nᵗʰ iteration of the l-loop, the absolute value of the
 871            // coefficients are bounded by nq.
 872
 873            // offset effectively loops over the row groups in this column; it is
 874            // the first row in the row group.
 875            var offset: usize = 0;
 876            while (offset < N - l) : (offset += 2 * l) {
 877                k += 1;
 878                const z = @as(i32, zetas[k]);
 879
 880                // j loops over each butterfly in the row group.
 881                for (offset..offset + l) |j| {
 882                    const t = montReduce(z * @as(i32, p.cs[j + l]));
 883                    p.cs[j + l] = p.cs[j] - t;
 884                    p.cs[j] += t;
 885                }
 886            }
 887        }
 888
 889        return p;
 890    }
 891
 892    // Executes an inverse "NTT" on p and multiply by the Montgomery factor R.
 893    //
 894    // Assumes the coefficients are in absolute value ≤q.  The resulting
 895    // coefficients are in absolute value ≤q.  If the input is in Montgomery
 896    // form, then the result is in Montgomery form and so (by linearity)
 897    // if the input is in regular form, then the result is also in regular form.
 898    fn invNTT(a: Poly) Poly {
 899        var k: usize = 127; // index into zetas
 900        var r: usize = 0; // index into invNTTReductions
 901        var p = a;
 902
 903        // We basically do the oppposite of NTT, but postpone dividing by 2 in the
 904        // inverse of the Cooley-Tukey butterfly and accumulate that into a big
 905        // division by 2⁷ at the end.  See the comments in the ntt() function.
 906
 907        var l: usize = 2;
 908        while (l < N) : (l <<= 1) {
 909            var offset: usize = 0;
 910            while (offset < N - l) : (offset += 2 * l) {
 911                // As we're inverting, we need powers of ζ⁻¹ (instead of ζ).
 912                // To be precise, we need ζᵇʳᵛ⁽ᵏ⁾⁻¹²⁸. However, as ζ⁻¹²⁸ = -1,
 913                // we can use the existing zetas table instead of
 914                // keeping a separate invZetas table as in Dilithium.
 915
 916                const minZeta = @as(i32, zetas[k]);
 917                k -= 1;
 918
 919                for (offset..offset + l) |j| {
 920                    // Gentleman-Sande butterfly: (a, b) ↦ (a + b, ζ(a-b))
 921                    const t = p.cs[j + l] - p.cs[j];
 922                    p.cs[j] += p.cs[j + l];
 923                    p.cs[j + l] = montReduce(minZeta * @as(i32, t));
 924
 925                    // Note that if we had |a| < αq and |b| < βq before the
 926                    // butterfly, then now we have |a| < (α+β)q and |b| < q.
 927                }
 928            }
 929
 930            // We let the invNTTReductions instruct us which coefficients to
 931            // Barrett reduce.
 932            while (true) {
 933                const i = inv_ntt_reductions[r];
 934                r += 1;
 935                if (i < 0) {
 936                    break;
 937                }
 938                p.cs[@as(usize, @intCast(i))] = feBarrettReduce(p.cs[@as(usize, @intCast(i))]);
 939            }
 940        }
 941
 942        for (0..N) |j| {
 943            // Note 1441 = (128)⁻¹ R².  The coefficients are bounded by 9q, so
 944            // as 1441 * 9 ≈ 2¹⁴ < 2¹⁵, we're within the required bounds
 945            // for montReduce().
 946            p.cs[j] = montReduce(r2_over_128 * @as(i32, p.cs[j]));
 947        }
 948
 949        return p;
 950    }
 951
 952    // Normalizes coefficients.
 953    //
 954    // Ensures each coefficient is in {0, …, q-1}.
 955    fn normalize(a: Poly) Poly {
 956        var ret: Poly = undefined;
 957        for (0..N) |i| {
 958            ret.cs[i] = csubq(feBarrettReduce(a.cs[i]));
 959        }
 960        return ret;
 961    }
 962
 963    // Put p in Montgomery form.
 964    fn toMont(a: Poly) Poly {
 965        var ret: Poly = undefined;
 966        for (0..N) |i| {
 967            ret.cs[i] = feToMont(a.cs[i]);
 968        }
 969        return ret;
 970    }
 971
 972    // Barret reduce coefficients.
 973    //
 974    // Beware, this does not fully normalize coefficients.
 975    fn barrettReduce(a: Poly) Poly {
 976        var ret: Poly = undefined;
 977        for (0..N) |i| {
 978            ret.cs[i] = feBarrettReduce(a.cs[i]);
 979        }
 980        return ret;
 981    }
 982
 983    fn compressedSize(comptime d: u8) usize {
 984        return @divTrunc(N * d, 8);
 985    }
 986
 987    // Returns packed Compress_q(p, d).
 988    //
 989    // Assumes p is normalized.
 990    fn compress(p: Poly, comptime d: u8) [compressedSize(d)]u8 {
 991        @setEvalBranchQuota(10000);
 992        const q_over_2: u32 = comptime @divTrunc(Q, 2); // (q-1)/2
 993        const two_d_min_1: u32 = comptime (1 << d) - 1; // 2ᵈ-1
 994        var in_off: usize = 0;
 995        var out_off: usize = 0;
 996
 997        const batch_size: usize = comptime math.lcm(d, 8);
 998        const in_batch_size: usize = comptime batch_size / d;
 999        const out_batch_size: usize = comptime batch_size / 8;
1000
1001        const out_length: usize = comptime @divTrunc(N * d, 8);
1002        comptime assert(out_length * 8 == d * N);
1003        var out = [_]u8{0} ** out_length;
1004
1005        while (in_off < N) {
1006            // First we compress into in.
1007            var in: [in_batch_size]u16 = undefined;
1008            inline for (0..in_batch_size) |i| {
1009                // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
1010                //                  = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
1011                //                  = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
1012                //                  = DIV((x << d) + q/2, q) & ((1<<d) - 1)
1013                const t = @as(u24, @intCast(p.cs[in_off + i])) << d;
1014                // Division by invariant multiplication, equivalent to DIV(t + q/2, q).
1015                // A division may not be a constant-time operation, even with a constant denominator.
1016                // Here, side channels would leak information about the shared secret, see https://kyberslash.cr.yp.to
1017                // Multiplication, on the other hand, is a constant-time operation on the CPUs we currently support.
1018                comptime assert(d <= 11);
1019                comptime assert(((20642679 * @as(u64, Q)) >> 36) == 1);
1020                const u: u32 = @intCast((@as(u64, t + q_over_2) * 20642679) >> 36);
1021                in[i] = @intCast(u & two_d_min_1);
1022            }
1023
1024            // Now we pack the d-bit integers from `in' into out as bytes.
1025            comptime var in_shift: usize = 0;
1026            comptime var j: usize = 0;
1027            comptime var i: usize = 0;
1028            inline while (i < in_batch_size) : (j += 1) {
1029                comptime var todo: usize = 8;
1030                inline while (todo > 0) {
1031                    const out_shift = comptime 8 - todo;
1032                    out[out_off + j] |= @as(u8, @truncate((in[i] >> in_shift) << out_shift));
1033
1034                    const done = comptime @min(@min(d, todo), d - in_shift);
1035                    todo -= done;
1036                    in_shift += done;
1037
1038                    if (in_shift == d) {
1039                        in_shift = 0;
1040                        i += 1;
1041                    }
1042                }
1043            }
1044
1045            in_off += in_batch_size;
1046            out_off += out_batch_size;
1047        }
1048
1049        return out;
1050    }
1051
1052    // Set p to Decompress_q(m, d).
1053    fn decompress(comptime d: u8, in: *const [compressedSize(d)]u8) Poly {
1054        @setEvalBranchQuota(10000);
1055        const in_len = comptime @divTrunc(N * d, 8);
1056        comptime assert(in_len * 8 == d * N);
1057        var ret: Poly = undefined;
1058        var in_off: usize = 0;
1059        var out_off: usize = 0;
1060
1061        const batch_size: usize = comptime math.lcm(d, 8);
1062        const in_batch_size: usize = comptime batch_size / 8;
1063        const out_batch_size: usize = comptime batch_size / d;
1064
1065        while (out_off < N) {
1066            comptime var in_shift: usize = 0;
1067            comptime var j: usize = 0;
1068            comptime var i: usize = 0;
1069            inline while (i < out_batch_size) : (i += 1) {
1070                // First, unpack next coefficient.
1071                comptime var todo = d;
1072                var out: u16 = 0;
1073
1074                inline while (todo > 0) {
1075                    const out_shift = comptime d - todo;
1076                    const m = comptime (1 << d) - 1;
1077                    out |= (@as(u16, in[in_off + j] >> in_shift) << out_shift) & m;
1078
1079                    const done = comptime @min(@min(8, todo), 8 - in_shift);
1080                    todo -= done;
1081                    in_shift += done;
1082
1083                    if (in_shift == 8) {
1084                        in_shift = 0;
1085                        j += 1;
1086                    }
1087                }
1088
1089                // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
1090                //                    = ⌊(q/2ᵈ)x+½⌋
1091                //                    = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
1092                //                    = (qx + (1<<(d-1))) >> d
1093                const qx = @as(u32, out) * @as(u32, Q);
1094                ret.cs[out_off + i] = @as(i16, @intCast((qx + (1 << (d - 1))) >> d));
1095            }
1096
1097            in_off += in_batch_size;
1098            out_off += out_batch_size;
1099        }
1100
1101        return ret;
1102    }
1103
1104    // Returns the "pointwise" multiplication a o b.
1105    //
1106    // That is: invNTT(a o b) = invNTT(a) * invNTT(b).  Assumes a and b are in
1107    // Montgomery form.  Products between coefficients of a and b must be strictly
1108    // bounded in absolute value by 2¹⁵q.  a o b will be in Montgomery form and
1109    // bounded in absolute value by 2q.
1110    fn mulHat(a: Poly, b: Poly) Poly {
1111        // Recall from the discussion in ntt(), that a transformed polynomial is
1112        // an element of ℤ_q[x]/(x²-ζ) x … x  ℤ_q[x]/(x²+ζ¹²⁷);
1113        // that is: 128 degree-one polynomials instead of simply 256 elements
1114        // from ℤ_q as in the regular NTT.  So instead of pointwise multiplication,
1115        // we multiply the 128 pairs of degree-one polynomials modulo the
1116        // right equation:
1117        //
1118        //  (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
1119        //
1120        // where ζ' is the appropriate power of ζ.
1121
1122        var p: Poly = undefined;
1123        var k: usize = 64;
1124        var i: usize = 0;
1125        while (i < N) : (i += 4) {
1126            const z = @as(i32, zetas[k]);
1127            k += 1;
1128
1129            const a1b1 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i + 1]));
1130            const a0b0 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i]));
1131            const a1b0 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i]));
1132            const a0b1 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i + 1]));
1133
1134            p.cs[i] = montReduce(a1b1 * z) + a0b0;
1135            p.cs[i + 1] = a0b1 + a1b0;
1136
1137            const a3b3 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 3]));
1138            const a2b2 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 2]));
1139            const a3b2 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 2]));
1140            const a2b3 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 3]));
1141
1142            p.cs[i + 2] = a2b2 - montReduce(a3b3 * z);
1143            p.cs[i + 3] = a2b3 + a3b2;
1144        }
1145
1146        return p;
1147    }
1148
1149    // Sample p from a centered binomial distribution with n=2η and p=½ - viz:
1150    // coefficients are in {-η, …, η} with probabilities
1151    //
1152    //  {ncr(0, 2η)/2^2η, ncr(1, 2η)/2^2η, …, ncr(2η,2η)/2^2η}
1153    fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Poly {
1154        var h = sha3.Shake256.init(.{});
1155        const suffix: [1]u8 = .{nonce};
1156        h.update(seed);
1157        h.update(&suffix);
1158
1159        // The distribution at hand is exactly the same as that
1160        // of (a₁ + a₂ + … + a_η) - (b₁ + … + b_η) where a_i,b_i~U(1).
1161        // Thus we need 2η bits per coefficient.
1162        const buf_len = comptime 2 * eta * N / 8;
1163        var buf: [buf_len]u8 = undefined;
1164        h.squeeze(&buf);
1165
1166        // buf is interpreted as a₁…a_ηb₁…b_ηa₁…a_ηb₁…b_η…. We process
1167        // multiple coefficients in one batch.
1168
1169        const T = switch (builtin.target.cpu.arch) {
1170            .x86_64, .x86 => u32, // Generates better code on Intel CPUs
1171            else => u64, // u128 might be faster on some other CPUs.
1172        };
1173
1174        comptime var batch_count: usize = undefined;
1175        comptime var batch_bytes: usize = undefined;
1176        comptime var mask: T = 0;
1177        comptime {
1178            batch_count = @bitSizeOf(T) / @as(usize, 2 * eta);
1179            while (@rem(N, batch_count) != 0 and batch_count > 0) : (batch_count -= 1) {}
1180            assert(batch_count > 0);
1181            assert(@rem(2 * eta * batch_count, 8) == 0);
1182            batch_bytes = 2 * eta * batch_count / 8;
1183
1184            for (0..2 * eta * batch_count) |_| {
1185                mask <<= eta;
1186                mask |= 1;
1187            }
1188        }
1189
1190        var ret: Poly = undefined;
1191        for (0..comptime N / batch_count) |i| {
1192            // Read coefficients into t. In the case of η=3,
1193            // we have t = a₁ + 2a₂ + 4a₃ + 8b₁ + 16b₂ + …
1194            var t: T = 0;
1195            inline for (0..batch_bytes) |j| {
1196                t |= @as(T, buf[batch_bytes * i + j]) << (8 * j);
1197            }
1198
1199            // Accumulate `a's and `b's together by masking them out, shifting
1200            // and adding. For η=3, we have  d = a₁ + a₂ + a₃ + 8(b₁ + b₂ + b₃) + …
1201            var d: T = 0;
1202            inline for (0..eta) |j| {
1203                d += (t >> j) & mask;
1204            }
1205
1206            // Extract each a and b separately and set coefficient in polynomial.
1207            inline for (0..batch_count) |j| {
1208                const mask2 = comptime (1 << eta) - 1;
1209                const a = @as(i16, @intCast((d >> (comptime (2 * j * eta))) & mask2));
1210                const b = @as(i16, @intCast((d >> (comptime ((2 * j + 1) * eta))) & mask2));
1211                ret.cs[batch_count * i + j] = a - b;
1212            }
1213        }
1214
1215        return ret;
1216    }
1217
1218    fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
1219        const domain_sep: [2]u8 = .{ x, y };
1220        return sampleUniformRejection(
1221            Poly,
1222            Q,
1223            12,
1224            N,
1225            &seed,
1226            &domain_sep,
1227        );
1228    }
1229
1230    // Packs p.
1231    //
1232    // Assumes p is normalized (and not just Barrett reduced).
1233    fn toBytes(p: Poly) [encoded_length]u8 {
1234        var ret: [encoded_length]u8 = undefined;
1235        for (0..comptime N / 2) |i| {
1236            const t0 = @as(u16, @intCast(p.cs[2 * i]));
1237            const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
1238            ret[3 * i] = @as(u8, @truncate(t0));
1239            ret[3 * i + 1] = @as(u8, @truncate((t0 >> 8) | (t1 << 4)));
1240            ret[3 * i + 2] = @as(u8, @truncate(t1 >> 4));
1241        }
1242        return ret;
1243    }
1244
1245    // Unpacks a Poly from buf.
1246    //
1247    // p will not be normalized; instead 0 ≤ p[i] < 4096.
1248    fn fromBytes(buf: *const [encoded_length]u8) Poly {
1249        var ret: Poly = undefined;
1250        for (0..comptime N / 2) |i| {
1251            const b0 = @as(i16, buf[3 * i]);
1252            const b1 = @as(i16, buf[3 * i + 1]);
1253            const b2 = @as(i16, buf[3 * i + 2]);
1254            ret.cs[2 * i] = b0 | ((b1 & 0xf) << 8);
1255            ret.cs[2 * i + 1] = (b1 >> 4) | b2 << 4;
1256        }
1257        return ret;
1258    }
1259};
1260
1261// A vector of k polynomials.
1262fn PolyVec(comptime k: u8) type {
1263    return struct {
1264        ps: [k]Poly,
1265
1266        const Self = @This();
1267        const encoded_length = k * Poly.encoded_length;
1268
1269        fn compressedSize(comptime d: u8) usize {
1270            return Poly.compressedSize(d) * k;
1271        }
1272
1273        /// Apply unary operation to each polynomial
1274        fn map(v: Self, comptime op: fn (Poly) Poly) Self {
1275            var ret: Self = undefined;
1276            inline for (0..k) |i| {
1277                ret.ps[i] = op(v.ps[i]);
1278            }
1279            return ret;
1280        }
1281
1282        /// Apply binary operation pairwise
1283        fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
1284            var ret: Self = undefined;
1285            inline for (0..k) |i| {
1286                ret.ps[i] = op(a.ps[i], b.ps[i]);
1287            }
1288            return ret;
1289        }
1290
1291        fn ntt(v: Self) Self {
1292            return map(v, Poly.ntt);
1293        }
1294
1295        fn invNTT(v: Self) Self {
1296            return map(v, Poly.invNTT);
1297        }
1298
1299        fn normalize(v: Self) Self {
1300            return map(v, Poly.normalize);
1301        }
1302
1303        fn barrettReduce(v: Self) Self {
1304            return map(v, Poly.barrettReduce);
1305        }
1306
1307        fn add(a: Self, b: Self) Self {
1308            return mapBinary(a, b, Poly.add);
1309        }
1310
1311        fn sub(a: Self, b: Self) Self {
1312            return mapBinary(a, b, Poly.sub);
1313        }
1314
1315        // Samples v[i] from centered binomial distribution with the given η,
1316        // seed and nonce+i.
1317        fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
1318            var ret: Self = undefined;
1319            for (0..k) |i| {
1320                ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
1321            }
1322            return ret;
1323        }
1324
1325        // Sets p to the inner product of a and b using "pointwise" multiplication.
1326        //
1327        // See MulHat() and NTT() for a description of the multiplication.
1328        // Assumes a and b are in Montgomery form.  p will be in Montgomery form,
1329        // and its coefficients will be bounded in absolute value by 2kq.
1330        // If a and b are not in Montgomery form, then the action is the same
1331        // as "pointwise" multiplication followed by multiplying by R⁻¹, the inverse
1332        // of the Montgomery factor.
1333        fn dotHat(a: Self, b: Self) Poly {
1334            var ret: Poly = Poly.zero;
1335            for (0..k) |i| {
1336                ret = ret.add(a.ps[i].mulHat(b.ps[i]));
1337            }
1338            return ret;
1339        }
1340
1341        fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
1342            const cs = comptime Poly.compressedSize(d);
1343            var ret: [compressedSize(d)]u8 = undefined;
1344            inline for (0..k) |i| {
1345                ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
1346            }
1347            return ret;
1348        }
1349
1350        fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
1351            const cs = comptime Poly.compressedSize(d);
1352            var ret: Self = undefined;
1353            inline for (0..k) |i| {
1354                ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
1355            }
1356            return ret;
1357        }
1358
1359        /// Serializes the key into a byte array.
1360        fn toBytes(v: Self) [encoded_length]u8 {
1361            var ret: [encoded_length]u8 = undefined;
1362            inline for (0..k) |i| {
1363                ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
1364            }
1365            return ret;
1366        }
1367
1368        /// Deserializes the key from a byte array.
1369        fn fromBytes(buf: *const [encoded_length]u8) Self {
1370            var ret: Self = undefined;
1371            inline for (0..k) |i| {
1372                ret.ps[i] = Poly.fromBytes(
1373                    buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
1374                );
1375            }
1376            return ret;
1377        }
1378    };
1379}
1380
1381// A matrix of k vectors
1382fn Mat(comptime k: u8) type {
1383    return struct {
1384        const Self = @This();
1385        rows: [k]PolyVec(k),
1386
1387        fn uniform(seed: [32]u8, comptime transposed: bool) Self {
1388            var ret: Self = undefined;
1389            var i: u8 = 0;
1390            while (i < k) : (i += 1) {
1391                var j: u8 = 0;
1392                while (j < k) : (j += 1) {
1393                    ret.rows[i].ps[j] = Poly.uniform(
1394                        seed,
1395                        if (transposed) i else j,
1396                        if (transposed) j else i,
1397                    );
1398                }
1399            }
1400            return ret;
1401        }
1402
1403        // Returns transpose of A
1404        fn transpose(m: Self) Self {
1405            var ret: Self = undefined;
1406            for (0..k) |i| {
1407                for (0..k) |j| {
1408                    ret.rows[i].ps[j] = m.rows[j].ps[i];
1409                }
1410            }
1411            return ret;
1412        }
1413    };
1414}
1415
// Returns 1 if a ≠ b and 0 if a = b (a u1, not a bool).  The comparison is
// done with crypto.timing_safe.eql, the constant-time equality check.
fn ctneq(comptime len: usize, a: [len]u8, b: [len]u8) u1 {
    const equal: u1 = @intFromBool(crypto.timing_safe.eql([len]u8, a, b));
    return 1 - equal;
}
1420
// Overwrites dst with src when b = 1 and leaves dst unchanged when b = 0,
// without branching on b: b is expanded into an all-ones/all-zeros byte mask.
fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
    const select: u8 = @as(u8, 0) -% b;
    for (dst, src) |*out, in| {
        out.* ^= select & (out.* ^ in);
    }
}
1428
// Test helper: returns a polynomial whose coefficients are drawn uniformly
// from {-q, …, q} using the given PRNG.
fn randPolyAbsLeqQ(rnd: anytype) Poly {
    const random = rnd.random();
    var out: Poly = undefined;
    for (&out.cs) |*coeff| {
        coeff.* = random.intRangeAtMost(i16, -Q, Q);
    }
    return out;
}
1437
// Test helper: returns a normalized polynomial, i.e. coefficients drawn
// uniformly from {0, …, q-1} using the given PRNG.
fn randPolyNormalized(rnd: anytype) Poly {
    const random = rnd.random();
    var out: Poly = undefined;
    for (&out.cs) |*coeff| {
        coeff.* = random.intRangeLessThan(i16, 0, Q);
    }
    return out;
}
1446
test "MulHat" {
    if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;

    var rnd = RndGen.init(0);

    for (0..100) |_| {
        const a = randPolyAbsLeqQ(&rnd);
        const b = randPolyAbsLeqQ(&rnd);

        // Product computed in the NTT domain.
        const fast = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();

        // Reference: schoolbook negacyclic product, each term scaled by R⁻¹
        // via montReduce (undone by toMont below).
        var slow = Poly.zero;
        for (a.cs, 0..) |ac, i| {
            for (b.cs, 0..) |bc, j| {
                var v = montReduce(@as(i32, ac) * @as(i32, bc));
                var k = i + j;
                if (k >= N) {
                    // Recall Xᴺ = -1.
                    k -= N;
                    v = -v;
                }
                slow.cs[k] = feBarrettReduce(v + slow.cs[k]);
            }
        }
        slow = slow.toMont().normalize();

        try testing.expectEqual(slow, fast);
    }
}
1479
test "NTT" {
    var rnd = RndGen.init(0);

    for (0..1000) |_| {
        const original = randPolyAbsLeqQ(&rnd);
        // invNTT multiplies by R, so the round trip lands on toMont(original).
        const expected = original.toMont().normalize();

        const transformed = original.ntt();
        for (transformed.cs) |c| {
            try testing.expect(-7 * Q <= c and c <= 7 * Q);
        }

        const restored = transformed.normalize().invNTT();
        for (restored.cs) |c| {
            try testing.expect(-Q <= c and c <= Q);
        }

        try testing.expectEqual(restored.normalize(), expected);
    }
}
1502
test "Compression" {
    var rnd = RndGen.init(0);
    // compress ∘ decompress ∘ compress must equal compress for every d in use.
    inline for (.{ 1, 4, 5, 10, 11 }) |d| {
        for (0..1000) |_| {
            const packed_bytes = randPolyNormalized(&rnd).compress(d);
            const roundtrip = Poly.decompress(d, &packed_bytes).compress(d);
            try testing.expectEqual(packed_bytes, roundtrip);
        }
    }
}
1514
test "noise" {
    // Deterministic seed 0, 1, …, 31.
    var seed: [32]u8 = undefined;
    for (0..seed.len) |i| seed[i] = @intCast(i);

    // Known-answer output for η = 3, nonce 37.
    const want_eta3 = [_]i16{
        0,  0,  1,  -1, 0,  2,  0,  -1, -1, 3,  0,  1,  -2, -2, 0,  1,  -2,
        1,  0,  -2, 3,  0,  0,  0,  1,  3,  1,  1,  2,  1,  -1, -1, -1, 0,
        1,  0,  1,  0,  2,  0,  1,  -2, 0,  -1, -1, -2, 1,  -1, -1, 2,  -1,
        1,  1,  2,  -3, -1, -1, 0,  0,  0,  0,  1,  -1, -2, -2, 0,  -2, 0,
        0,  0,  1,  0,  -1, -1, 1,  -2, 2,  0,  0,  2,  -2, 0,  1,  0,  1,
        1,  1,  0,  1,  -2, -1, -2, -1, 1,  0,  0,  0,  0,  0,  1,  0,  -1,
        -1, 0,  -1, 1,  0,  1,  0,  -1, -1, 0,  -2, 2,  0,  -2, 1,  -1, 0,
        1,  -1, -1, 2,  1,  0,  0,  -2, -1, 2,  0,  0,  0,  -1, -1, 3,  1,
        0,  1,  0,  1,  0,  2,  1,  0,  0,  1,  0,  1,  0,  0,  -1, -1, -1,
        0,  1,  3,  1,  0,  1,  0,  1,  -1, -1, -1, -1, 0,  0,  -2, -1, -1,
        2,  0,  1,  0,  1,  0,  2,  -2, 0,  1,  1,  -3, -1, -2, -1, 0,  1,
        0,  1,  -2, 2,  2,  1,  1,  0,  -1, 0,  -1, -1, 1,  0,  -1, 2,  1,
        -1, 1,  2,  -2, 1,  2,  0,  1,  2,  1,  0,  0,  2,  1,  2,  1,  0,
        2,  1,  0,  0,  -1, -1, 1,  -1, 0,  1,  -1, 2,  2,  0,  0,  -1, 1,
        1,  1,  1,  0,  0,  -2, 0,  -1, 1,  2,  0,  0,  1,  1,  -1, 1,  0,
        1,
    };
    try testing.expectEqual(Poly.noise(3, 37, &seed).cs, want_eta3);

    // Known-answer output for η = 2, nonce 37.
    const want_eta2 = [_]i16{
        1,  0,  1,  -1, -1, -2, -1, -1, 2,  0,  -1, 0,  0,  -1,
        1,  1,  -1, 1,  0,  2,  -2, 0,  1,  2,  0,  0,  -1, 1,
        0,  -1, 1,  -1, 1,  2,  1,  1,  0,  -1, 1,  -1, -2, -1,
        1,  -1, -1, -1, 2,  -1, -1, 0,  0,  1,  1,  -1, 1,  1,
        1,  1,  -1, -2, 0,  1,  0,  0,  2,  1,  -1, 2,  0,  0,
        1,  1,  0,  -1, 0,  0,  -1, -1, 2,  0,  1,  -1, 2,  -1,
        -1, -1, -1, 0,  -2, 0,  2,  1,  0,  0,  0,  -1, 0,  0,
        0,  -1, -1, 0,  -1, -1, 0,  -1, 0,  0,  -2, 1,  1,  0,
        1,  0,  1,  0,  1,  1,  -1, 2,  0,  1,  -1, 1,  2,  0,
        0,  0,  0,  -1, -1, -1, 0,  1,  0,  -1, 2,  0,  0,  1,
        1,  1,  0,  1,  -1, 1,  2,  1,  0,  2,  -1, 1,  -1, -2,
        -1, -2, -1, 1,  0,  -2, -2, -1, 1,  0,  0,  0,  0,  1,
        0,  0,  0,  2,  2,  0,  1,  0,  -1, -1, 0,  2,  0,  0,
        -2, 1,  0,  2,  1,  -1, -2, 0,  0,  -1, 1,  1,  0,  0,
        2,  0,  1,  1,  -2, 1,  -2, 1,  1,  0,  2,  0,  -1, 0,
        -1, 0,  1,  2,  0,  1,  0,  -2, 1,  -2, -2, 1,  -1, 0,
        -1, 1,  1,  0,  0,  0,  1,  0,  -1, 1,  1,  0,  0,  0,
        0,  1,  0,  1,  -1, 0,  1,  -1, -1, 2,  0,  0,  1,  -1,
        0,  1,  -1, 0,
    };
    try testing.expectEqual(Poly.noise(2, 37, &seed).cs, want_eta2);
}
1560
test "uniform sampling" {
    // Deterministic seed 0, 1, …, 31.
    var seed: [32]u8 = undefined;
    for (0..seed.len) |i| seed[i] = @intCast(i);

    // Known-answer output for domain separation (x, y) = (1, 0).
    const want = [_]i16{
        797,  993,  161,  6,    2608, 2385, 2096, 2661, 1676, 247,  2440,
        342,  634,  194,  1570, 2848, 986,  684,  3148, 3208, 2018, 351,
        2288, 612,  1394, 170,  1521, 3119, 58,   596,  2093, 1549, 409,
        2156, 1934, 1730, 1324, 388,  446,  418,  1719, 2202, 1812, 98,
        1019, 2369, 214,  2699, 28,   1523, 2824, 273,  402,  2899, 246,
        210,  1288, 863,  2708, 177,  3076, 349,  44,   949,  854,  1371,
        957,  292,  2502, 1617, 1501, 254,  7,    1761, 2581, 2206, 2655,
        1211, 629,  1274, 2358, 816,  2766, 2115, 2985, 1006, 2433, 856,
        2596, 3192, 1,    1378, 2345, 707,  1891, 1669, 536,  1221, 710,
        2511, 120,  1176, 322,  1897, 2309, 595,  2950, 1171, 801,  1848,
        695,  2912, 1396, 1931, 1775, 2904, 893,  2507, 1810, 2873, 253,
        1529, 1047, 2615, 1687, 831,  1414, 965,  3169, 1887, 753,  3246,
        1937, 115,  2953, 586,  545,  1621, 1667, 3187, 1654, 1988, 1857,
        512,  1239, 1219, 898,  3106, 391,  1331, 2228, 3169, 586,  2412,
        845,  768,  156,  662,  478,  1693, 2632, 573,  2434, 1671, 173,
        969,  364,  1663, 2701, 2169, 813,  1000, 1471, 720,  2431, 2530,
        3161, 733,  1691, 527,  2634, 335,  26,   2377, 1707, 767,  3020,
        950,  502,  426,  1138, 3208, 2607, 2389, 44,   1358, 1392, 2334,
        875,  2097, 173,  1697, 2578, 942,  1817, 974,  1165, 2853, 1958,
        2973, 3282, 271,  1236, 1677, 2230, 673,  1554, 96,   242,  1729,
        2518, 1884, 2272, 71,   1382, 924,  1807, 1610, 456,  1148, 2479,
        2152, 238,  2208, 2329, 713,  1175, 1196, 757,  1078, 3190, 3169,
        708,  3117, 154,  1751, 3225, 1364, 154,  23,   2842, 1105, 1419,
        79,   5,    2013,
    };
    try testing.expectEqual(Poly.uniform(seed, 1, 0).cs, want);
}
1593
test "Polynomial packing" {
    var rnd = RndGen.init(0);

    // fromBytes must invert toBytes on normalized polynomials.
    for (0..1000) |_| {
        const original = randPolyNormalized(&rnd);
        const encoded = original.toBytes();
        try testing.expectEqual(Poly.fromBytes(&encoded), original);
    }
}
1602
test "Test inner PKE" {
    if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;

    // seed = 0, 1, …, 31 and plaintext = 32, 33, …, 63.
    var seed: [32]u8 = undefined;
    var pt: [32]u8 = undefined;
    for (0..32) |i| {
        seed[i] = @intCast(i);
        pt[i] = @intCast(i + 32);
    }

    inline for (modes) |mode| {
        for (0..10) |key_round| {
            // Vary the key-generation seed per round.
            seed[0] = @intCast(key_round);
            var pk: mode.InnerPk = undefined;
            var sk: mode.InnerSk = undefined;
            mode.innerKeyFromSeed(seed, &pk, &sk);

            for (0..10) |enc_round| {
                // Vary the encryption randomness per round.
                seed[1] = @intCast(enc_round);
                const ct = pk.encrypt(&pt, &seed);
                try testing.expectEqual(sk.decrypt(&ct), pt);
            }
        }
    }
}
1625
test "Test happy flow" {
    if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;

    var seed: [64]u8 = undefined;
    for (0..seed.len) |idx| seed[idx] = @intCast(idx);

    inline for (modes) |mode| {
        for (0..10) |key_round| {
            seed[0] = @intCast(key_round);
            const kp = try mode.KeyPair.generateDeterministic(seed);

            // Serializing and re-parsing each key must reproduce it exactly.
            const sk = try mode.SecretKey.fromBytes(&kp.secret_key.toBytes());
            try testing.expectEqual(sk, kp.secret_key);
            const pk = try mode.PublicKey.fromBytes(&kp.public_key.toBytes());
            try testing.expectEqual(pk, kp.public_key);

            // Encapsulating to the parsed public key and decapsulating with the parsed
            // secret key must agree on the shared secret.
            for (0..10) |enc_round| {
                seed[1] = @intCast(enc_round);
                const encapsulated = pk.encaps(seed[0..32].*);
                const decapsulated = try sk.decaps(&encapsulated.ciphertext);
                try testing.expectEqual(encapsulated.shared_secret, decapsulated);
            }
        }
    }
}
1649
1650// Code to test NIST Known Answer Tests (KAT), see PQCgenKAT.c.
1651
test "NIST KAT test d00.Kyber512" {
    // Skipped on LoongArch with LSX and on s390x with the vector facility.
    if (comptime (builtin.cpu.has(.loongarch, .lsx) or builtin.cpu.has(.s390x, .vector)))
        return error.SkipZigTest;

    try testNistKat(d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547");
}
1658
test "NIST KAT test d00.Kyber1024" {
    // Skipped on LoongArch with LSX and on s390x with the vector facility.
    if (comptime (builtin.cpu.has(.loongarch, .lsx) or builtin.cpu.has(.s390x, .vector)))
        return error.SkipZigTest;

    try testNistKat(d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5");
}
1665
test "NIST KAT test d00.Kyber768" {
    // Skipped on LoongArch with LSX and on s390x with the vector facility.
    if (comptime (builtin.cpu.has(.loongarch, .lsx) or builtin.cpu.has(.s390x, .vector)))
        return error.SkipZigTest;

    try testNistKat(d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2");
}
1672
/// Regenerates 100 entries of a NIST-style KAT response file for `mode`.
///
/// The rendered text is hashed with SHA-256, and the lowercase hex digest must
/// equal `hash`. Every entry also checks that decapsulating the ciphertext recovers
/// the encapsulated shared secret.
fn testNistKat(mode: type, hash: []const u8) !void {
    // The master DRBG is seeded with the bytes 0, 1, ..., 47.
    var seed: [48]u8 = undefined;
    for (0..seed.len) |idx| seed[idx] = @intCast(idx);

    var kat_out: std.Io.Writer.Hashing(crypto.hash.sha2.Sha256) = .init(&.{});
    const kat = &kat_out.writer;
    var master_drbg = NistDRBG.init(seed);
    try kat.print("# {s}\n\n", .{mode.name});

    for (0..100) |count| {
        master_drbg.fill(&seed);
        try kat.print("count = {}\n", .{count});
        try kat.print("seed = {X}\n", .{&seed});

        var drbg = NistDRBG.init(seed);
        // Two separate 32-byte fills, not one 64-byte fill: the reference implementation
        // calls randombytes twice during key generation, and every NistDRBG.fill ends
        // by updating the DRBG state, so the split changes the output.
        var key_seed: [64]u8 = undefined;
        var enc_seed: [32]u8 = undefined;
        drbg.fill(key_seed[0..32]);
        drbg.fill(key_seed[32..64]);
        drbg.fill(&enc_seed);

        const kp = try mode.KeyPair.generateDeterministic(key_seed);
        const encap = kp.public_key.encaps(enc_seed);
        const recovered = try kp.secret_key.decaps(&encap.ciphertext);
        try testing.expectEqual(recovered, encap.shared_secret);

        try kat.print("pk = {X}\n", .{&kp.public_key.toBytes()});
        try kat.print("sk = {X}\n", .{&kp.secret_key.toBytes()});
        try kat.print("ct = {X}\n", .{&encap.ciphertext});
        try kat.print("ss = {X}\n\n", .{&encap.shared_secret});
    }

    var digest: [32]u8 = undefined;
    kat_out.hasher.final(&digest);
    const digest_hex = std.fmt.bytesToHex(digest, .lower);
    try testing.expectEqualStrings(&digest_hex, hash);
}
1711
/// Deterministic random bit generator that reproduces the AES-256 CTR DRBG
/// (`AES256_CTR_DRBG_Update` / `randombytes`) used by the NIST PQC KAT generator.
/// It exists only so the known-answer tests can recreate the reference test vectors.
const NistDRBG = struct {
    key: [32]u8,
    v: [16]u8,

    /// Increments the 128-bit big-endian counter `v` by one.
    /// An all-0xFF counter wraps to all zeros. The previous byte loop underflowed its
    /// `usize` index in that case.
    fn incV(g: *NistDRBG) void {
        const counter = std.mem.readInt(u128, &g.v, .big);
        std.mem.writeInt(u128, &g.v, counter +% 1, .big);
    }

    // AES256_CTR_DRBG_Update(pd, &g.key, &g.v): derives a new (key, v) pair from
    // three AES-256 encryptions of the incremented counter. When `pd` is provided,
    // it is XORed into the 48 derived bytes first.
    fn update(g: *NistDRBG, pd: ?[48]u8) void {
        const ctx = crypto.core.aes.Aes256.initEnc(g.key);
        var buf: [48]u8 = undefined;
        for (0..3) |i| {
            g.incV();
            ctx.encrypt(buf[i * 16 ..][0..16], &g.v);
        }
        if (pd) |data| {
            for (&buf, data) |*b, x| b.* ^= x;
        }
        g.key = buf[0..32].*;
        g.v = buf[32..48].*;
    }

    // randombytes: fills `out` with AES-256 encryptions of successive counter values,
    // truncating the last block, then refreshes the state via `update(null)`.
    fn fill(g: *NistDRBG, out: []u8) void {
        const ctx = crypto.core.aes.Aes256.initEnc(g.key);
        var dst = out;
        while (dst.len > 0) {
            g.incV();
            var block: [16]u8 = undefined;
            ctx.encrypt(&block, &g.v);
            const take = @min(dst.len, block.len);
            @memcpy(dst[0..take], block[0..take]);
            dst = dst[take..];
        }
        g.update(null);
    }

    /// Starts from an all-zero key and counter, then mixes in the 48-byte `seed`.
    fn init(seed: [48]u8) NistDRBG {
        var ret: NistDRBG = .{ .key = .{0} ** 32, .v = .{0} ** 16 };
        ret.update(seed);
        return ret;
    }
};
1773
/// Extended Euclidean algorithm.
///
/// Returns `gcd`, `x` and `y` such that `a_ * x + b_ * y == gcd`.
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
/// `T` must be a signed integer type wide enough for the intermediate Bézout
/// coefficients and the `q * x1` / `q * y1` products formed in the loop.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    var a = a_;
    var b = b_;
    var x0: T = 1;
    var x1: T = 0;
    var y0: T = 0;
    var y1: T = 1;

    while (b != 0) {
        const q = @divTrunc(a, b);
        const temp_a = a;
        a = b;
        b = temp_a - q * b;

        const temp_x = x0;
        x0 = x1;
        x1 = temp_x - q * x1;

        const temp_y = y0;
        y0 = y1;
        y1 = temp_y - q * y1;
    }

    return .{ .gcd = a, .x = x0, .y = y0 };
}

/// Modular inversion: computes a^(-1) mod p.
///
/// Requires gcd(a, p) = 1 (asserted) and p > 0. The result is normalized to [0, p).
/// Accepts signed and unsigned integer `T`, including unsigned moduli at or above
/// 2^(bits-1) such as `modularInverse(u16, 3, 65521)`.
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // Run the EEA in a signed type twice as wide as `T`. Reinterpreting an unsigned
    // `T` as the same-width signed type would reject values >= 2^(bits-1). The extra
    // width also leaves headroom for the products computed inside the EEA loop.
    const type_info = @typeInfo(T);
    const SignedT = if (type_info == .int)
        std.meta.Int(.signed, type_info.int.bits * 2)
    else
        T;

    const a_signed = @as(SignedT, @intCast(a));
    const p_signed = @as(SignedT, @intCast(p));

    const r = extendedEuclidean(SignedT, a_signed, p_signed);
    assert(r.gcd == 1);

    // @mod yields a value in [0, p) for any sign of r.x (given p > 0), which
    // always fits back into `T`.
    return @intCast(@mod(r.x, p_signed));
}
1826
/// Modular exponentiation: returns a^s mod p.
///
/// Right-to-left binary (square-and-multiply) method. Each product is formed in an
/// unsigned integer twice as wide as `T` before reduction, so it cannot overflow.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const DoubleT = std.meta.Int(.unsigned, 2 * @typeInfo(T).int.bits);
    const Ops = struct {
        /// (x * y) mod p, computed in the double-width type.
        fn mulMod(x: T, y: T) T {
            return @intCast(@as(DoubleT, x) * @as(DoubleT, y) % p);
        }
    };

    var acc: T = 1;
    var power: T = a;
    var rest: T = s;
    while (rest > 0) : (rest >>= 1) {
        // Multiply in the current power of `a` whenever the low exponent bit is set.
        if ((rest & 1) != 0) acc = Ops.mulMod(acc, power);
        power = Ops.mulMod(power, power);
    }
    return acc;
}
1847
/// Expands a single bit into a full-width mask.
///
/// `bit == 0` gives all zeros; `bit == 1` gives all ones (0xFF...FF).
/// `T` must be an unsigned integer type; anything else is a compile error.
fn bitMask(comptime T: type, bit: T) T {
    const int_info = switch (@typeInfo(T)) {
        .int => |info| info,
        else => @compileError("bitMask requires an unsigned integer type"),
    };
    if (int_info.signedness != .unsigned) {
        @compileError("bitMask requires an unsigned integer type");
    }
    // Wrapping negation: 0 stays 0, 1 becomes maxInt(T).
    return 0 -% bit;
}
1857
/// Builds a mask from the most significant bit of `x`.
///
/// Returns all ones (0xFF...FF) when that bit is set, i.e. `x < 0` for signed `T`,
/// and all zeros otherwise. Unsigned inputs are reinterpreted as signed, so for them
/// the MSB is what counts. The result is the unsigned type of the same width as `T`.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const int_info = switch (@typeInfo(T)) {
        .int => |info| info,
        else => @compileError("signMask requires an integer type"),
    };
    const top_bit = int_info.bits - 1;
    const as_signed: std.meta.Int(.signed, int_info.bits) = switch (int_info.signedness) {
        .signed => x,
        .unsigned => @bitCast(x),
    };
    // An arithmetic right shift copies the top bit into every position.
    return @bitCast(as_signed >> top_bit);
}
1874
test "bitMask and signMask helpers" {
    // bitMask: 0 -> all zeros, 1 -> all ones, at several widths.
    inline for (.{ u8, u32, u64 }) |MaskT| {
        try testing.expectEqual(@as(MaskT, 0), bitMask(MaskT, 0));
        try testing.expectEqual(@as(MaskT, std.math.maxInt(MaskT)), bitMask(MaskT, 1));
    }

    // signMask on signed input: negative -> all ones, non-negative -> all zeros.
    const all_ones: u32 = 0xFFFFFFFF;
    for ([_]i32{ -1, -100 }) |v| try testing.expectEqual(all_ones, signMask(i32, v));
    for ([_]i32{ 0, 1, 100 }) |v| try testing.expectEqual(@as(u32, 0), signMask(i32, v));

    // signMask on unsigned input keys off the most significant bit.
    try testing.expectEqual(all_ones, signMask(u32, 0x80000000));
    try testing.expectEqual(@as(u32, 0), signMask(u32, 0x7FFFFFFF));
}
1892
/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q),
/// with R = 2^r_bits.
///
/// Generic over the modulus `q`, `qInv` and the Montgomery exponent `r_bits`.
/// `qInv` is q^(-1) mod R, given either as its unsigned value or as its signed
/// representative (e.g. 62209 or -3327 for q = 3329, R = 2^16). It must fit in `InT`.
///
/// Intended uses, per the original design notes:
/// - ML-KEM: R = 2^16, result in (-q, q) (for suitably bounded x).
/// - ML-DSA: R = 2^32, result < 2q. NOTE(review): confirm against the ML-DSA caller.
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    // m = x * q^(-1) mod R, so that x - m*q is an exact multiple of R.
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);

    const yR = x -% @as(InT, m) * @as(InT, q);

    // Shift as unsigned, then keep the low bits(OutT) bits and reinterpret as OutT.
    // When OutT is exactly bits(InT) - r_bits wide (e.g. i32 -> i16 with
    // r_bits = 16), this equals the exact quotient yR / R.
    const UnsignedIn = std.meta.Int(.unsigned, @typeInfo(InT).int.bits);
    const UnsignedOut = std.meta.Int(.unsigned, @typeInfo(OutT).int.bits);
    const y_shifted = @as(UnsignedIn, @bitCast(yR)) >> r_bits;
    return @bitCast(@as(UnsignedOut, @truncate(y_shifted)));
}
1915
/// Samples a polynomial with coefficients uniform in [0, q), using SHAKE-128 and
/// rejection sampling.
///
/// The XOF is absorbed with `seed` followed by `domain_sep`. Its output is then
/// consumed in 3-byte chunks, one SHAKE-128 block at a time. Candidates >= q are
/// discarded until `n` coefficients have been written to `PolyType.cs`.
///
/// - `bits_per_coef == 12` (ML-KEM): each chunk gives two 12-bit little-endian candidates.
/// - `bits_per_coef == 23` (ML-DSA): each chunk gives one 24-bit little-endian value
///   masked to 23 bits.
/// Any other `bits_per_coef` is a compile error.
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    var xof = sha3.Shake128.init(.{});
    xof.update(seed);
    xof.update(domain_sep);

    var block: [sha3.Shake128.block_length]u8 = undefined;
    var out: PolyType = undefined;
    var filled: usize = 0;

    switch (bits_per_coef) {
        12 => {
            sampling: while (true) {
                xof.squeeze(&block);
                for (0..block.len / 3) |chunk_idx| {
                    const chunk = block[3 * chunk_idx ..][0..3];
                    const candidates = [2]u16{
                        @as(u16, chunk[0]) | (@as(u16, chunk[1] & 0x0f) << 8),
                        (@as(u16, chunk[1]) >> 4) | (@as(u16, chunk[2]) << 4),
                    };
                    inline for (candidates) |cand| {
                        if (cand < q) {
                            out.cs[filled] = @intCast(cand);
                            filled += 1;
                            if (filled == n) break :sampling;
                        }
                    }
                }
            }
        },
        23 => {
            while (filled < n) {
                xof.squeeze(&block);
                var chunk_idx: usize = 0;
                while (chunk_idx < block.len / 3 and filled < n) : (chunk_idx += 1) {
                    const chunk = block[3 * chunk_idx ..][0..3];
                    const cand = (@as(u32, chunk[0]) |
                        (@as(u32, chunk[1]) << 8) |
                        (@as(u32, chunk[2]) << 16)) & 0x7fffff;
                    if (cand < q) {
                        out.cs[filled] = @intCast(cand);
                        filled += 1;
                    }
                }
            }
        },
        else => @compileError("bits_per_coef must be 12 or 23"),
    }

    return out;
}