//! Implementation of the IND-CCA2 post-quantum secure key encapsulation mechanism (KEM)
//! ML-KEM (NIST FIPS-203 publication) and CRYSTALS-Kyber (v3.02/"draft00" CFRG draft).
//!
//! The namespace `d00` refers to the version currently implemented, in accordance with the CFRG draft.
//! The `nist` namespace refers to the FIPS-203 publication.
//!
//! Quoting from the CFRG I-D:
//!
//! Kyber is not a Diffie-Hellman (DH) style non-interactive key
//! agreement, but instead, Kyber is a Key Encapsulation Method (KEM).
//! In essence, a KEM is a Public-Key Encryption (PKE) scheme where the
//! plaintext cannot be specified, but is generated as a random key as
//! part of the encryption. A KEM can be transformed into an unrestricted
//! PKE using HPKE (RFC9180). On its own, a KEM can be used as a key
//! agreement method in TLS.
//!
//! Kyber is an IND-CCA2 secure KEM. It is constructed by applying a
//! Fujisaki-Okamoto style transformation on InnerPKE, which is the
//! underlying IND-CPA secure Public Key Encryption scheme. We cannot
//! use InnerPKE directly, as its ciphertexts are malleable.
//!
//! ```
//!                     F.O. transform
//!     InnerPKE   ----------------------->   Kyber
//!     IND-CPA                              IND-CCA2
//! ```
//!
//! Kyber is a lattice-based scheme. More precisely, its security is
//! based on the learning-with-errors-and-rounding problem in module
//! lattices (MLWER). The underlying polynomial ring R (defined in
//! Section 5) is chosen such that multiplication is very fast using the
//! number theoretic transform (NTT, see Section 5.1.3).
//!
//! An InnerPKE private key is a vector _s_ over R of length k which is
//! _small_ in a particular way. Here k is a security parameter akin to
//! the size of a prime modulus. For Kyber512, which targets AES-128's
//! security level, the value of k is 2.
//!
//! The public key consists of two values:
//!
//! * _A_ a uniformly sampled k by k matrix over R _and_
//!
//! * _t = A s + e_, where e is a suitably small masking vector.
//!
//! Distinguishing between such A s + e and a uniformly sampled t is the
//! module learning-with-errors (MLWE) problem. If that is hard, then it
//! is also hard to recover the private key from the public key as that
//! would allow you to distinguish between those two.
//!
//! To save space in the public key, A is recomputed deterministically
//! from a seed _rho_.
//!
//! A ciphertext for a message m under this public key is a pair (c_1,
//! c_2) computed roughly as follows:
//!
//!     c_1 = Compress(A^T r + e_1, d_u)
//!     c_2 = Compress(t^T r + e_2 + Decompress(m, 1), d_v)
//!
//! where
//!
//! * e_1, e_2 and r are small blinds;
//!
//! * Compress(-, d) removes some information, leaving d bits per
//!   coefficient and Decompress is such that Compress after Decompress
//!   does nothing and
//!
//! * d_u, d_v are scheme parameters.
//!
//! Distinguishing such a ciphertext and uniformly sampled (c_1, c_2) is
//! an example of the full MLWER problem, see section 4.4 of [KyberV302].
//!
//! To decrypt the ciphertext, one computes
//!
//!     m = Compress(Decompress(c_2, d_v) - s^T Decompress(c_1, d_u), 1).
//!
//! It is not straightforward to see that this formula is correct. In
//! fact, there is a negligible but non-zero probability that a ciphertext
//! does not decrypt correctly, given by the DFP column in Table 4. This
//! failure probability can be computed by a careful automated analysis
//! of the probabilities involved, see kyber_failure.py of [SecEst].
//!
//! [KyberV302](https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf)
//! [I-D](https://github.com/bwesterb/draft-schwabe-cfrg-kyber)
//! [SecEst](https://github.com/pq-crystals/security-estimates)
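//!
//! A minimal usage sketch of the API defined below (see also the "Test happy
//! flow" test further down in this file); `try` is used as it would appear
//! inside a test or function body:
//!
//! ```
//! const kem = d00.Kyber768;
//!
//! // Alice generates a key pair and publishes `kp.public_key`.
//! const kp = kem.KeyPair.generate();
//!
//! // Bob encapsulates a fresh 32-byte shared secret to the public key.
//! const enc = kp.public_key.encaps(null);
//!
//! // Alice decapsulates the ciphertext and obtains the same shared secret.
//! const shared_secret = try kp.secret_key.decaps(&enc.ciphertext);
//! ```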
85
// TODO
//
// - The bottlenecks in Kyber are the various hash/XOF calls:
//    - Optimize Zig's keccak implementation.
//    - Use SIMD to compute keccak in parallel.
// - Can we track bounds of coefficients using comptime types without
//   duplicating code?
// - Would be neater to have tests closer to the thing under test.
// - When generating a keypair, we have a copy of the inner public key with
//   its large matrix A in both the public key and the private key. In Go we
//   can just have a pointer in the private key to the public key, but
//   how do we do this elegantly in Zig?
98
99const std = @import("std");
100const builtin = @import("builtin");
101
102const testing = std.testing;
103const assert = std.debug.assert;
104const crypto = std.crypto;
105const errors = std.crypto.errors;
106const math = std.math;
107const mem = std.mem;
108const sha3 = crypto.hash.sha3;
109
110const RndGen = std.Random.DefaultPrng;
111
// Q is the modulus q = 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1.
const Q: i16 = 3329;

// R = 2¹⁶, the Montgomery factor used for Montgomery multiplication.
const R: i32 = 1 << 16;

// N is the number of coefficients of a polynomial (the dimension of the ring R).
const N: usize = 256;

// eta2 is the size of "small" vectors used in encryption blinds.
const eta2: u8 = 2;
123
124const Params = struct {
125 name: []const u8,
126
    // Whether to use the NIST ML-KEM variant instead of Kyber as originally submitted.
    ml_kem: bool = false,
129
130 // Width and height of the matrix A.
131 k: u8,
132
133 // Size of "small" vectors used in private key and encryption blinds.
134 eta1: u8,
135
136 // How many bits to retain of u, the private-key independent part
137 // of the ciphertext.
138 du: u8,
139
140 // How many bits to retain of v, the private-key dependent part
141 // of the ciphertext.
142 dv: u8,
143};
144
145pub const d00 = struct {
146 pub const Kyber512 = Kyber(.{
147 .name = "Kyber512",
148 .k = 2,
149 .eta1 = 3,
150 .du = 10,
151 .dv = 4,
152 });
153
154 pub const Kyber768 = Kyber(.{
155 .name = "Kyber768",
156 .k = 3,
157 .eta1 = 2,
158 .du = 10,
159 .dv = 4,
160 });
161
162 pub const Kyber1024 = Kyber(.{
163 .name = "Kyber1024",
164 .k = 4,
165 .eta1 = 2,
166 .du = 11,
167 .dv = 5,
168 });
169};
170
171pub const nist = struct {
172 pub const MLKem512 = Kyber(.{
173 .name = "ML-KEM-512",
174 .ml_kem = true,
175 .k = 2,
176 .eta1 = 3,
177 .du = 10,
178 .dv = 4,
179 });
180
181 pub const MLKem768 = Kyber(.{
182 .name = "ML-KEM-768",
183 .ml_kem = true,
184 .k = 3,
185 .eta1 = 2,
186 .du = 10,
187 .dv = 4,
188 });
189
190 pub const MLKem1024 = Kyber(.{
191 .name = "ML-KEM-1024",
192 .ml_kem = true,
193 .k = 4,
194 .eta1 = 2,
195 .du = 11,
196 .dv = 5,
197 });
198};
199
200const modes = [_]type{
201 d00.Kyber512,
202 d00.Kyber768,
203 d00.Kyber1024,
204 nist.MLKem512,
205 nist.MLKem768,
206 nist.MLKem1024,
207};
208const h_length: usize = 32;
209const inner_seed_length: usize = 32;
210const common_encaps_seed_length: usize = 32;
211const common_shared_key_size: usize = 32;
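// A sanity check on the parameter sets above, not tied to any test vector: for
// Kyber768 / ML-KEM-768 the derived encoding sizes must come out to the
// well-known values (1184-byte public keys, 2400-byte secret keys, 1088-byte
// ciphertexts, 32-byte shared secrets), matching the sizes listed in FIPS 203.
test "Kyber768 derived sizes" {
    try testing.expectEqual(@as(usize, 1184), d00.Kyber768.PublicKey.encoded_length);
    try testing.expectEqual(@as(usize, 2400), d00.Kyber768.SecretKey.encoded_length);
    try testing.expectEqual(@as(usize, 1088), d00.Kyber768.ciphertext_length);
    try testing.expectEqual(@as(usize, 32), d00.Kyber768.shared_length);
    try testing.expectEqual(@as(usize, 1088), nist.MLKem768.ciphertext_length);
}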
212
213fn Kyber(comptime p: Params) type {
214 return struct {
215 // Size of a ciphertext, in bytes.
216 pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
217
218 const Self = @This();
219 const V = PolyVec(p.k);
220 const M = Mat(p.k);
221
222 /// Length (in bytes) of a shared secret.
223 pub const shared_length = common_shared_key_size;
224 /// Length (in bytes) of a seed for deterministic encapsulation.
225 pub const encaps_seed_length = common_encaps_seed_length;
226 /// Length (in bytes) of a seed for key generation.
227 pub const seed_length: usize = inner_seed_length + shared_length;
228 /// Algorithm name.
229 pub const name = p.name;
230
231 /// A shared secret, and an encapsulated (encrypted) representation of it.
232 pub const EncapsulatedSecret = struct {
233 shared_secret: [shared_length]u8,
234 ciphertext: [ciphertext_length]u8,
235 };
236
237 /// A Kyber public key.
238 pub const PublicKey = struct {
239 pk: InnerPk,
240
241 // Cached
242 hpk: [h_length]u8, // H(pk)
243
244 /// Size of a serialized representation of the key, in bytes.
245 pub const encoded_length = InnerPk.encoded_length;
246
247 /// Generates a shared secret, and encapsulates it for the public key.
248 /// If `seed` is `null`, a random seed is used. This is recommended.
249 /// If `seed` is set, encapsulation is deterministic.
250 pub fn encaps(pk: PublicKey, seed_: ?[encaps_seed_length]u8) EncapsulatedSecret {
251 var m: [inner_plaintext_length]u8 = undefined;
252
253 if (seed_) |seed| {
254 if (p.ml_kem) {
255 @memcpy(&m, &seed);
256 } else {
257 // m = H(seed)
258 sha3.Sha3_256.hash(&seed, &m, .{});
259 }
260 } else {
261 crypto.random.bytes(&m);
262 }
263
264 // (K', r) = G(m ‖ H(pk))
265 var kr: [inner_plaintext_length + h_length]u8 = undefined;
266 var g = sha3.Sha3_512.init(.{});
267 g.update(&m);
268 g.update(&pk.hpk);
269 g.final(&kr);
270
271 // c = innerEncrypt(pk, m, r)
272 const ct = pk.pk.encrypt(&m, kr[32..64]);
273
274 if (p.ml_kem) {
275 return EncapsulatedSecret{
276 .shared_secret = kr[0..shared_length].*, // ML-KEM: K = K'
277 .ciphertext = ct,
278 };
279 } else {
280 // Compute H(c) and put in second slot of kr, which will be (K', H(c)).
281 sha3.Sha3_256.hash(&ct, kr[32..], .{});
282
283 var ss: [shared_length]u8 = undefined;
284 sha3.Shake256.hash(&kr, &ss, .{});
285 return EncapsulatedSecret{
286 .shared_secret = ss, // Kyber: K = KDF(K' ‖ H(c))
287 .ciphertext = ct,
288 };
289 }
290 }
291
292 /// Serializes the key into a byte array.
293 pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
294 return pk.pk.toBytes();
295 }
296
297 /// Deserializes the key from a byte array.
298 pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
299 var ret: PublicKey = undefined;
300 ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
301 sha3.Sha3_256.hash(buf, &ret.hpk, .{});
302 return ret;
303 }
304 };
305
306 /// A Kyber secret key.
307 pub const SecretKey = struct {
308 sk: InnerSk,
309 pk: InnerPk,
310 hpk: [h_length]u8, // H(pk)
311 z: [shared_length]u8,
312
313 /// Size of a serialized representation of the key, in bytes.
314 pub const encoded_length: usize =
315 InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
316
317 /// Decapsulates the shared secret within ct using the private key.
318 pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
319 // m' = innerDec(ct)
320 const m2 = sk.sk.decrypt(ct);
321
322 // (K'', r') = G(m' ‖ H(pk))
323 var kr2: [64]u8 = undefined;
324 var g = sha3.Sha3_512.init(.{});
325 g.update(&m2);
326 g.update(&sk.hpk);
327 g.final(&kr2);
328
329 // ct' = innerEnc(pk, m', r')
330 const ct2 = sk.pk.encrypt(&m2, kr2[32..64]);
331
332 // Compute H(ct) and put in the second slot of kr2 which will be (K'', H(ct)).
333 sha3.Sha3_256.hash(ct, kr2[32..], .{});
334
335 // Replace K'' by z in the first slot of kr2 if ct ≠ ct'.
336 cmov(32, kr2[0..32], sk.z, ctneq(ciphertext_length, ct.*, ct2));
337
338 if (p.ml_kem) {
339 // ML-KEM: K = K''/z
340 return kr2[0..shared_length].*;
341 } else {
342 // Kyber: K = KDF(K''/z ‖ H(c))
343 var ss: [shared_length]u8 = undefined;
344 sha3.Shake256.hash(&kr2, &ss, .{});
345 return ss;
346 }
347 }
348
349 /// Serializes the key into a byte array.
350 pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
351 return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
352 }
353
354 /// Deserializes the key from a byte array.
355 pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
356 var ret: SecretKey = undefined;
357 comptime var s: usize = 0;
358 ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
359 s += InnerSk.encoded_length;
360 ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
361 s += InnerPk.encoded_length;
362 ret.hpk = buf[s..][0..h_length].*;
363 s += h_length;
364 ret.z = buf[s..][0..shared_length].*;
365 return ret;
366 }
367 };
368
369 /// A Kyber key pair.
370 pub const KeyPair = struct {
371 secret_key: SecretKey,
372 public_key: PublicKey,
373
            /// Deterministically derive a key pair from a cryptographically secure secret seed.
375 ///
376 /// Except in tests, applications should generally call `generate()` instead of this function.
377 pub fn generateDeterministic(seed: [seed_length]u8) !KeyPair {
378 var ret: KeyPair = undefined;
379
380 // Generate inner key
381 innerKeyFromSeed(
382 seed[0..inner_seed_length].*,
383 &ret.public_key.pk,
384 &ret.secret_key.sk,
385 );
386 ret.secret_key.pk = ret.public_key.pk;
387
388 // Copy over z from seed.
389 ret.secret_key.z = seed[inner_seed_length..seed_length].*;
390
391 // Compute H(pk)
392 sha3.Sha3_256.hash(&ret.public_key.pk.toBytes(), &ret.secret_key.hpk, .{});
393 ret.public_key.hpk = ret.secret_key.hpk;
394
395 return ret;
396 }
397
398 /// Generate a new, random key pair.
399 pub fn generate() KeyPair {
400 var random_seed: [seed_length]u8 = undefined;
401 while (true) {
402 crypto.random.bytes(&random_seed);
403 return generateDeterministic(random_seed) catch {
404 @branchHint(.unlikely);
405 continue;
406 };
407 }
408 }
409 };
410
        // Size (in bytes) of plaintexts of the inner PKE.
        const inner_plaintext_length: usize = Poly.compressedSize(1);
413
414 const InnerPk = struct {
415 rho: [32]u8, // ρ, the seed for the matrix A
416 th: V, // NTT(t), normalized
417
418 // Cached values
419 aT: M,
420
421 const encoded_length = V.encoded_length + 32;
422
423 fn encrypt(
424 pk: InnerPk,
425 pt: *const [inner_plaintext_length]u8,
426 seed: *const [32]u8,
427 ) [ciphertext_length]u8 {
428 // Sample r, e₁ and e₂ appropriately
429 const rh = V.noise(p.eta1, 0, seed).ntt().barrettReduce();
430 const e1 = V.noise(eta2, p.k, seed);
431 const e2 = Poly.noise(eta2, 2 * p.k, seed);
432
433 // Next we compute u = Aᵀ r + e₁. First Aᵀ.
434 var u: V = undefined;
435 for (0..p.k) |i| {
436 // Note that coefficients of r are bounded by q and those of Aᵀ
437 // are bounded by 4.5q and so their product is bounded by 2¹⁵q
438 // as required for multiplication.
439 u.ps[i] = pk.aT.rows[i].dotHat(rh);
440 }
441
442 // Aᵀ and r were not in Montgomery form, so the Montgomery
443 // multiplications in the inner product added a factor R⁻¹ which
444 // the InvNTT cancels out.
445 u = u.barrettReduce().invNTT().add(e1).normalize();
446
447 // Next, compute v = <t, r> + e₂ + Decompress_q(m, 1)
448 const v = pk.th.dotHat(rh).barrettReduce().invNTT()
449 .add(Poly.decompress(1, pt)).add(e2).normalize();
450
451 return u.compress(p.du) ++ v.compress(p.dv);
452 }
453
454 fn toBytes(pk: InnerPk) [encoded_length]u8 {
455 return pk.th.toBytes() ++ pk.rho;
456 }
457
458 fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
459 var ret: InnerPk = undefined;
460
461 const th_bytes = buf[0..V.encoded_length];
462 ret.th = V.fromBytes(th_bytes).normalize();
463
464 if (p.ml_kem) {
465 // Verify that the coefficients used a canonical representation.
466 if (!mem.eql(u8, &ret.th.toBytes(), th_bytes)) {
467 return error.NonCanonical;
468 }
469 }
470
471 ret.rho = buf[V.encoded_length..encoded_length].*;
472 ret.aT = M.uniform(ret.rho, true);
473 return ret;
474 }
475 };
476
477 // Private key of the inner PKE
478 const InnerSk = struct {
479 sh: V, // NTT(s), normalized
480 const encoded_length = V.encoded_length;
481
482 fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
483 const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
484 const v = Poly.decompress(
485 p.dv,
486 ct[comptime V.compressedSize(p.du)..ciphertext_length],
487 );
488
489 // Compute m = v - <s, u>
490 return v.sub(sk.sh.dotHat(u.ntt()).barrettReduce().invNTT())
491 .normalize().compress(1);
492 }
493
494 fn toBytes(sk: InnerSk) [encoded_length]u8 {
495 return sk.sh.toBytes();
496 }
497
498 fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
499 var ret: InnerSk = undefined;
500 ret.sh = V.fromBytes(buf).normalize();
501 return ret;
502 }
503 };
504
505 // Derives inner PKE keypair from given seed.
506 fn innerKeyFromSeed(seed: [inner_seed_length]u8, pk: *InnerPk, sk: *InnerSk) void {
507 var expanded_seed: [64]u8 = undefined;
508 var h = sha3.Sha3_512.init(.{});
509 h.update(&seed);
510 if (p.ml_kem) h.update(&[1]u8{p.k});
511 h.final(&expanded_seed);
512 pk.rho = expanded_seed[0..32].*;
513 const sigma = expanded_seed[32..64];
514 pk.aT = M.uniform(pk.rho, false); // Expand ρ to A; we'll transpose later on
515
516 // Sample secret vector s.
517 sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
518
519 const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
520 var th: V = undefined;
521
522 // Next, we compute t = A s + e.
523 for (0..p.k) |i| {
524 // Note that coefficients of s are bounded by q and those of A
525 // are bounded by 4.5q and so their product is bounded by 2¹⁵q
526 // as required for multiplication.
527 // A and s were not in Montgomery form, so the Montgomery
528 // multiplications in the inner product added a factor R⁻¹ which
529 // we'll cancel out with toMont(). This will also ensure the
530 // coefficients of th are bounded in absolute value by q.
531 th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
532 }
533
534 pk.th = th.add(eh).normalize(); // bounded by 8q
535 pk.aT = pk.aT.transpose();
536 }
537 };
538}
539
540// R mod q
541const r_mod_q: i32 = @rem(@as(i32, R), Q);
542
543// R² mod q
544const r2_mod_q: i32 = @rem(r_mod_q * r_mod_q, Q);
545
// ζ is a primitive 256ᵗʰ root of unity modulo q, used for the NTT.
const zeta: i16 = 17;
548
549// (128)⁻¹ R². Used in inverse NTT.
550const r2_over_128: i32 = @mod(invertMod(128, Q) * r2_mod_q, Q);
551
552// zetas lists precomputed powers of the primitive root of unity in
553// Montgomery representation used for the NTT:
554//
555// zetas[i] = ζᵇʳᵛ⁽ⁱ⁾ R mod q
556//
557// where ζ = 17, brv(i) is the bitreversal of a 7-bit number and R=2¹⁶ mod q.
558const zetas = computeZetas();
559
560// invNTTReductions keeps track of which coefficients to apply Barrett
561// reduction to in Poly.invNTT().
562//
563// Generated lazily: once a butterfly is computed which is about to
564// overflow the i16, the largest coefficient is reduced. If that is
565// not enough, the other coefficient is reduced as well.
566//
567// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
568const inv_ntt_reductions = [_]i16{
569 -1, // after layer 1
570 -1, // after layer 2
571 16,
572 17,
573 48,
574 49,
575 80,
576 81,
577 112,
578 113,
579 144,
580 145,
581 176,
582 177,
583 208,
584 209,
585 240, 241, -1, // after layer 3
586 0, 1, 32,
587 33, 34, 35,
588 64, 65, 96,
589 97, 98, 99,
590 128, 129,
591 160, 161, 162, 163, 192, 193, 224, 225, 226, 227, -1, // after layer 4
592 2, 3, 66, 67, 68, 69, 70, 71, 130, 131, 194,
593 195, 196, 197,
594 198, 199, -1, // after layer 5
595 4, 5, 6,
596 7, 132, 133,
597 134, 135, 136,
598 137, 138, 139,
599 140, 141,
600 142, 143, -1, // after layer 6
601 -1, // after layer 7
602};
603
604test "invNTTReductions bounds" {
605 // Checks whether the reductions proposed by invNTTReductions
606 // don't overflow during invNTT().
607 var xs = [_]i32{1} ** 256; // start at |x| ≤ q
608
609 var r: usize = 0;
610 var layer: math.Log2Int(usize) = 1;
611 while (layer < 8) : (layer += 1) {
612 const w = @as(usize, 1) << layer;
613 var i: usize = 0;
614
615 while (i + w < 256) {
616 xs[i] = xs[i] + xs[i + w];
617 try testing.expect(xs[i] <= 9); // we can't exceed 9q
618 xs[i + w] = 1;
619 i += 1;
620 if (@mod(i, w) == 0) {
621 i += w;
622 }
623 }
624
625 while (true) {
626 const j = inv_ntt_reductions[r];
627 r += 1;
628 if (j < 0) {
629 break;
630 }
631 xs[@as(usize, @intCast(j))] = 1;
632 }
633 }
634}
635
636fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
637 const r = extendedEuclidean(@TypeOf(a), a, p);
638 assert(r.gcd == 1);
639 return r.x;
640}
641
642// Reduce mod q for testing.
643fn modQ32(x: i32) i16 {
644 var y = @as(i16, @intCast(@rem(x, @as(i32, Q))));
645 if (y < 0) {
646 y += Q;
647 }
648 return y;
649}
650
651// Given -2¹⁵ q ≤ x < 2¹⁵ q, returns -q < y < q with x 2⁻¹⁶ = y (mod q).
652fn montReduce(x: i32) i16 {
653 const qInv = comptime invertMod(@as(i32, Q), R);
654 // This is Montgomery reduction with R=2¹⁶.
655 //
656 // Note gcd(2¹⁶, q) = 1 as q is prime. Write q' := 62209 = q⁻¹ mod R.
657 // First we compute
658 //
659 // m := ((x mod R) q') mod R
660 // = x q' mod R
661 // = int16(x q')
662 // = int16(int32(x) * int32(q'))
663 //
664 // Note that x q' might be as big as 2³² and could overflow the int32
665 // multiplication in the last line. However for any int32s a and b,
666 // we have int32(int64(a)*int64(b)) = int32(a*b) and so the result is ok.
667 const m: i16 = @truncate(@as(i32, @truncate(x *% qInv)));
668
669 // Note that x - m q is divisible by R; indeed modulo R we have
670 //
671 // x - m q ≡ x - x q' q ≡ x - x q⁻¹ q ≡ x - x = 0.
672 //
673 // We return y := (x - m q) / R. Note that y is indeed correct as
674 // modulo q we have
675 //
676 // y ≡ x R⁻¹ - m q R⁻¹ = x R⁻¹
677 //
    // and as both -2¹⁵ q ≤ m q, x < 2¹⁵ q, we have
    // -2¹⁶ q < x - m q < 2¹⁶ q and so -q < (x - m q) / R < q as desired.
680 const yR = x - @as(i32, m) * @as(i32, Q);
681 return @bitCast(@as(u16, @truncate(@as(u32, @bitCast(yR)) >> 16)));
682}
683
684test "Test montReduce" {
685 var rnd = RndGen.init(0);
686 for (0..1000) |_| {
687 const bound = comptime @as(i32, Q) * (1 << 15);
688 const x = rnd.random().intRangeLessThan(i32, -bound, bound);
689 const y = montReduce(x);
690 try testing.expect(-Q < y and y < Q);
691 try testing.expectEqual(modQ32(x), modQ32(@as(i32, y) * R));
692 }
693}
694
695// Given any x, return x R mod q where R=2¹⁶.
696fn feToMont(x: i16) i16 {
697 // Note |1353 x| ≤ 1353 2¹⁵ ≤ 13318 q ≤ 2¹⁵ q and so we're within
698 // the bounds of montReduce.
699 return montReduce(@as(i32, x) * r2_mod_q);
700}
701
702test "Test feToMont" {
703 var x: i32 = -(1 << 15);
704 while (x < 1 << 15) : (x += 1) {
705 const y = feToMont(@as(i16, @intCast(x)));
706 try testing.expectEqual(modQ32(@as(i32, y)), modQ32(x * r_mod_q));
707 }
708}
709
710// Given any x, compute 0 ≤ y ≤ q with x = y (mod q).
711//
712// Beware: we might have feBarrettReduce(x) = q ≠ 0 for some x. In fact,
713// this happens if and only if x = -nq for some positive integer n.
714fn feBarrettReduce(x: i16) i16 {
    // This is standard Barrett reduction.
    //
    // For any x we have x mod q = x - ⌊x/q⌋ q. We will use 20159/2²⁶ as
    // an approximation of 1/q. Note that 0 ≤ 20159/2²⁶ - 1/q ≤ 0.135/2²⁶
    // and so | x 20159/2²⁶ - x/q | ≤ 2⁻¹⁰ for |x| ≤ 2¹⁶. For all x
    // not a multiple of q, the number x/q is further than 1/q from any integer
    // and so ⌊x 20159/2²⁶⌋ = ⌊x/q⌋. If x is a multiple of q and x is positive,
    // then x 20159/2²⁶ is larger than x/q so ⌊x 20159/2²⁶⌋ = ⌊x/q⌋ as well.
    // Finally, if x is a negative multiple of q, then ⌊x 20159/2²⁶⌋ = ⌊x/q⌋-1.
    // Thus
    //                         [ q        if x = -nq for a positive integer n
    //  x - ⌊x 20159/2²⁶⌋ q  = [
    //                         [ x mod q  otherwise
    //
    // To actually compute this, note that
    //
    //  ⌊x 20159/2²⁶⌋ = (20159 x) >> 26.
732 return x -% @as(i16, @intCast((@as(i32, x) * 20159) >> 26)) *% Q;
733}
734
735test "Test Barrett reduction" {
736 var x: i32 = -(1 << 15);
737 while (x < 1 << 15) : (x += 1) {
738 var y1 = feBarrettReduce(@as(i16, @intCast(x)));
739 const y2 = @mod(@as(i16, @intCast(x)), Q);
740 if (x < 0 and @rem(-x, Q) == 0) {
741 y1 -= Q;
742 }
743 try testing.expectEqual(y1, y2);
744 }
745}
746
747// Returns x if x < q and x - q otherwise. Assumes x ≥ -29439.
748fn csubq(x: i16) i16 {
749 var r = x;
750 r -= Q;
751 r += (r >> 15) & Q;
752 return r;
753}
754
755test "Test csubq" {
756 var x: i32 = -29439;
757 while (x < 1 << 15) : (x += 1) {
758 const y1 = csubq(@as(i16, @intCast(x)));
759 var y2 = @as(i16, @intCast(x));
760 if (@as(i16, @intCast(x)) >= Q) {
761 y2 -= Q;
762 }
763 try testing.expectEqual(y1, y2);
764 }
765}
766
767// Computes zetas table used by ntt and invNTT.
768fn computeZetas() [128]i16 {
769 @setEvalBranchQuota(10000);
770 var ret: [128]i16 = undefined;
771 for (&ret, 0..) |*r, i| {
772 const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
773 r.* = csubq(feBarrettReduce(feToMont(t)));
774 }
775 return ret;
776}
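// Spot check of the zetas table against the comment above it: the first entry
// is ζᵇʳᵛ⁽⁰⁾ R = R mod q, and every entry is normalized to {0, …, q-1}. This is
// only a consistency check of the comptime computation, not a full test vector.
test "zetas table" {
    try testing.expectEqual(@as(i16, @intCast(r_mod_q)), zetas[0]);
    for (zetas) |z| {
        try testing.expect(0 <= z and z < Q);
    }
}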
777
// An element of our base ring R: a polynomial over ℤ_q modulo the
// relation Xᴺ = -1, where q = 3329 and N = 256.
//
// This type is also used to store NTT-transformed polynomials,
// see Poly.ntt().
//
// Coefficients aren't always reduced. See normalize().
785const Poly = struct {
786 cs: [N]i16,
787
788 const encoded_length = N / 2 * 3;
789 const zero: Poly = .{ .cs = .{0} ** N };
790
791 // Add two polynomials (coefficients not normalized)
792 fn add(a: Poly, b: Poly) Poly {
793 var ret: Poly = undefined;
794 for (0..N) |i| {
795 ret.cs[i] = a.cs[i] + b.cs[i];
796 }
797 return ret;
798 }
799
800 // Subtract two polynomials (coefficients not normalized)
801 fn sub(a: Poly, b: Poly) Poly {
802 var ret: Poly = undefined;
803 for (0..N) |i| {
804 ret.cs[i] = a.cs[i] - b.cs[i];
805 }
806 return ret;
807 }
808
809 // Executes a forward "NTT" on p.
810 //
811 // Assumes the coefficients are in absolute value ≤q. The resulting
812 // coefficients are in absolute value ≤7q. If the input is in Montgomery
813 // form, then the result is in Montgomery form and so (by linearity of the NTT)
814 // if the input is in regular form, then the result is also in regular form.
815 fn ntt(a: Poly) Poly {
816 // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512
817 // does not divide into q-1) and so we cannot do a regular NTT. ℤ_q
818 // does have a primitive 256ᵗʰ root of unity, the smallest of which
819 // is ζ := 17.
820 //
821 // Recall that our base ring R := ℤ_q[x] / (x²⁵⁶ + 1). The polynomial
822 // x²⁵⁶+1 will not split completely (as its roots would be 512ᵗʰ roots
823 // of unity.) However, it does split almost (using ζ¹²⁸ = -1):
824 //
825 // x²⁵⁶ + 1 = (x²)¹²⁸ - ζ¹²⁸
826 // = ((x²)⁶⁴ - ζ⁶⁴)((x²)⁶⁴ + ζ⁶⁴)
827 // = ((x²)³² - ζ³²)((x²)³² + ζ³²)((x²)³² - ζ⁹⁶)((x²)³² + ζ⁹⁶)
828 // ⋮
829 // = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷)
830 //
831 // Note that the powers of ζ that appear (from the second line down) are
832 // in binary
833 //
834 // 0100000 1100000
835 // 0010000 1010000 0110000 1110000
836 // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000
837 // …
838 //
839 // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit
840 // bitreversal of x. These powers of ζ are given by the Zetas array.
841 //
842 // The polynomials x² ± ζⁱ are irreducible and coprime, hence by
843 // the Chinese Remainder Theorem we know
844 //
845 // ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷)
846 //
847 // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ )
848 // is an isomorphism, which is the "NTT". It can be efficiently computed by
849 //
850 //
        //  a ↦ ( a mod (x²)⁶⁴ - ζ⁶⁴, a mod (x²)⁶⁴ + ζ⁶⁴ )
        //    ↦ ( a mod (x²)³² - ζ³², a mod (x²)³² + ζ³²,
        //        a mod (x²)³² - ζ⁹⁶, a mod (x²)³² + ζ⁹⁶ )
854 //
855 // et cetera
856 // If N was 8 then this can be pictured in the following diagram:
857 //
858 // https://cnx.org/resources/17ee4dfe517a6adda05377b25a00bf6e6c93c334/File0026.png
859 //
860 // Each cross is a Cooley-Tukey butterfly: it's the map
861 //
862 // (a, b) ↦ (a + ζb, a - ζb)
863 //
864 // for the appropriate power ζ for that column and row group.
865 var p = a;
866 var k: usize = 0; // index into zetas
867
868 var l = N >> 1;
869 while (l > 1) : (l >>= 1) {
870 // On the nᵗʰ iteration of the l-loop, the absolute value of the
871 // coefficients are bounded by nq.
872
873 // offset effectively loops over the row groups in this column; it is
874 // the first row in the row group.
875 var offset: usize = 0;
876 while (offset < N - l) : (offset += 2 * l) {
877 k += 1;
878 const z = @as(i32, zetas[k]);
879
880 // j loops over each butterfly in the row group.
881 for (offset..offset + l) |j| {
882 const t = montReduce(z * @as(i32, p.cs[j + l]));
883 p.cs[j + l] = p.cs[j] - t;
884 p.cs[j] += t;
885 }
886 }
887 }
888
889 return p;
890 }
891
    // Executes an inverse "NTT" on p and multiplies by the Montgomery factor R.
893 //
894 // Assumes the coefficients are in absolute value ≤q. The resulting
895 // coefficients are in absolute value ≤q. If the input is in Montgomery
896 // form, then the result is in Montgomery form and so (by linearity)
897 // if the input is in regular form, then the result is also in regular form.
898 fn invNTT(a: Poly) Poly {
899 var k: usize = 127; // index into zetas
900 var r: usize = 0; // index into invNTTReductions
901 var p = a;
902
        // We basically do the opposite of NTT, but postpone dividing by 2 in the
904 // inverse of the Cooley-Tukey butterfly and accumulate that into a big
905 // division by 2⁷ at the end. See the comments in the ntt() function.
906
907 var l: usize = 2;
908 while (l < N) : (l <<= 1) {
909 var offset: usize = 0;
910 while (offset < N - l) : (offset += 2 * l) {
911 // As we're inverting, we need powers of ζ⁻¹ (instead of ζ).
912 // To be precise, we need ζᵇʳᵛ⁽ᵏ⁾⁻¹²⁸. However, as ζ⁻¹²⁸ = -1,
913 // we can use the existing zetas table instead of
914 // keeping a separate invZetas table as in Dilithium.
915
916 const minZeta = @as(i32, zetas[k]);
917 k -= 1;
918
919 for (offset..offset + l) |j| {
920 // Gentleman-Sande butterfly: (a, b) ↦ (a + b, ζ(a-b))
921 const t = p.cs[j + l] - p.cs[j];
922 p.cs[j] += p.cs[j + l];
923 p.cs[j + l] = montReduce(minZeta * @as(i32, t));
924
925 // Note that if we had |a| < αq and |b| < βq before the
926 // butterfly, then now we have |a| < (α+β)q and |b| < q.
927 }
928 }
929
930 // We let the invNTTReductions instruct us which coefficients to
931 // Barrett reduce.
932 while (true) {
933 const i = inv_ntt_reductions[r];
934 r += 1;
935 if (i < 0) {
936 break;
937 }
938 p.cs[@as(usize, @intCast(i))] = feBarrettReduce(p.cs[@as(usize, @intCast(i))]);
939 }
940 }
941
942 for (0..N) |j| {
943 // Note 1441 = (128)⁻¹ R². The coefficients are bounded by 9q, so
944 // as 1441 * 9 ≈ 2¹⁴ < 2¹⁵, we're within the required bounds
945 // for montReduce().
946 p.cs[j] = montReduce(r2_over_128 * @as(i32, p.cs[j]));
947 }
948
949 return p;
950 }
951
952 // Normalizes coefficients.
953 //
954 // Ensures each coefficient is in {0, …, q-1}.
955 fn normalize(a: Poly) Poly {
956 var ret: Poly = undefined;
957 for (0..N) |i| {
958 ret.cs[i] = csubq(feBarrettReduce(a.cs[i]));
959 }
960 return ret;
961 }
962
963 // Put p in Montgomery form.
964 fn toMont(a: Poly) Poly {
965 var ret: Poly = undefined;
966 for (0..N) |i| {
967 ret.cs[i] = feToMont(a.cs[i]);
968 }
969 return ret;
970 }
971
    // Barrett reduce coefficients.
973 //
974 // Beware, this does not fully normalize coefficients.
975 fn barrettReduce(a: Poly) Poly {
976 var ret: Poly = undefined;
977 for (0..N) |i| {
978 ret.cs[i] = feBarrettReduce(a.cs[i]);
979 }
980 return ret;
981 }
982
983 fn compressedSize(comptime d: u8) usize {
984 return @divTrunc(N * d, 8);
985 }
986
987 // Returns packed Compress_q(p, d).
988 //
989 // Assumes p is normalized.
990 fn compress(p: Poly, comptime d: u8) [compressedSize(d)]u8 {
991 @setEvalBranchQuota(10000);
992 const q_over_2: u32 = comptime @divTrunc(Q, 2); // (q-1)/2
993 const two_d_min_1: u32 = comptime (1 << d) - 1; // 2ᵈ-1
994 var in_off: usize = 0;
995 var out_off: usize = 0;
996
997 const batch_size: usize = comptime math.lcm(d, 8);
998 const in_batch_size: usize = comptime batch_size / d;
999 const out_batch_size: usize = comptime batch_size / 8;
1000
1001 const out_length: usize = comptime @divTrunc(N * d, 8);
1002 comptime assert(out_length * 8 == d * N);
1003 var out = [_]u8{0} ** out_length;
1004
1005 while (in_off < N) {
1006 // First we compress into in.
1007 var in: [in_batch_size]u16 = undefined;
1008 inline for (0..in_batch_size) |i| {
1009 // Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
1010 // = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
1011 // = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
1012 // = DIV((x << d) + q/2, q) & ((1<<d) - 1)
1013 const t = @as(u24, @intCast(p.cs[in_off + i])) << d;
1014 // Division by invariant multiplication, equivalent to DIV(t + q/2, q).
1015 // A division may not be a constant-time operation, even with a constant denominator.
1016 // Here, side channels would leak information about the shared secret, see https://kyberslash.cr.yp.to
1017 // Multiplication, on the other hand, is a constant-time operation on the CPUs we currently support.
1018 comptime assert(d <= 11);
1019 comptime assert(((20642679 * @as(u64, Q)) >> 36) == 1);
1020 const u: u32 = @intCast((@as(u64, t + q_over_2) * 20642679) >> 36);
1021 in[i] = @intCast(u & two_d_min_1);
1022 }
1023
1024 // Now we pack the d-bit integers from `in' into out as bytes.
1025 comptime var in_shift: usize = 0;
1026 comptime var j: usize = 0;
1027 comptime var i: usize = 0;
1028 inline while (i < in_batch_size) : (j += 1) {
1029 comptime var todo: usize = 8;
1030 inline while (todo > 0) {
1031 const out_shift = comptime 8 - todo;
1032 out[out_off + j] |= @as(u8, @truncate((in[i] >> in_shift) << out_shift));
1033
1034 const done = comptime @min(@min(d, todo), d - in_shift);
1035 todo -= done;
1036 in_shift += done;
1037
1038 if (in_shift == d) {
1039 in_shift = 0;
1040 i += 1;
1041 }
1042 }
1043 }
1044
1045 in_off += in_batch_size;
1046 out_off += out_batch_size;
1047 }
1048
1049 return out;
1050 }
1051
1052 // Set p to Decompress_q(m, d).
1053 fn decompress(comptime d: u8, in: *const [compressedSize(d)]u8) Poly {
1054 @setEvalBranchQuota(10000);
1055 const in_len = comptime @divTrunc(N * d, 8);
1056 comptime assert(in_len * 8 == d * N);
1057 var ret: Poly = undefined;
1058 var in_off: usize = 0;
1059 var out_off: usize = 0;
1060
1061 const batch_size: usize = comptime math.lcm(d, 8);
1062 const in_batch_size: usize = comptime batch_size / 8;
1063 const out_batch_size: usize = comptime batch_size / d;
1064
1065 while (out_off < N) {
1066 comptime var in_shift: usize = 0;
1067 comptime var j: usize = 0;
1068 comptime var i: usize = 0;
1069 inline while (i < out_batch_size) : (i += 1) {
1070 // First, unpack next coefficient.
1071 comptime var todo = d;
1072 var out: u16 = 0;
1073
1074 inline while (todo > 0) {
1075 const out_shift = comptime d - todo;
1076 const m = comptime (1 << d) - 1;
1077 out |= (@as(u16, in[in_off + j] >> in_shift) << out_shift) & m;
1078
1079 const done = comptime @min(@min(8, todo), 8 - in_shift);
1080 todo -= done;
1081 in_shift += done;
1082
1083 if (in_shift == 8) {
1084 in_shift = 0;
1085 j += 1;
1086 }
1087 }
1088
1089 // Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
1090 // = ⌊(q/2ᵈ)x+½⌋
1091 // = ⌊(qx + 2ᵈ⁻¹)/2ᵈ⌋
1092 // = (qx + (1<<(d-1))) >> d
1093 const qx = @as(u32, out) * @as(u32, Q);
1094 ret.cs[out_off + i] = @as(i16, @intCast((qx + (1 << (d - 1))) >> d));
1095 }
1096
1097 in_off += in_batch_size;
1098 out_off += out_batch_size;
1099 }
1100
1101 return ret;
1102 }
1103
1104 // Returns the "pointwise" multiplication a o b.
1105 //
1106 // That is: invNTT(a o b) = invNTT(a) * invNTT(b). Assumes a and b are in
1107 // Montgomery form. Products between coefficients of a and b must be strictly
1108 // bounded in absolute value by 2¹⁵q. a o b will be in Montgomery form and
1109 // bounded in absolute value by 2q.
1110 fn mulHat(a: Poly, b: Poly) Poly {
1111 // Recall from the discussion in ntt(), that a transformed polynomial is
1112 // an element of ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷);
1113 // that is: 128 degree-one polynomials instead of simply 256 elements
1114 // from ℤ_q as in the regular NTT. So instead of pointwise multiplication,
1115 // we multiply the 128 pairs of degree-one polynomials modulo the
1116 // right equation:
1117 //
1118 // (a₁ + a₂x)(b₁ + b₂x) = a₁b₁ + a₂b₂ζ' + (a₁b₂ + a₂b₁)x,
1119 //
1120 // where ζ' is the appropriate power of ζ.
1121
1122 var p: Poly = undefined;
1123 var k: usize = 64;
1124 var i: usize = 0;
1125 while (i < N) : (i += 4) {
1126 const z = @as(i32, zetas[k]);
1127 k += 1;
1128
1129 const a1b1 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i + 1]));
1130 const a0b0 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i]));
1131 const a1b0 = montReduce(@as(i32, a.cs[i + 1]) * @as(i32, b.cs[i]));
1132 const a0b1 = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[i + 1]));
1133
1134 p.cs[i] = montReduce(a1b1 * z) + a0b0;
1135 p.cs[i + 1] = a0b1 + a1b0;
1136
1137 const a3b3 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 3]));
1138 const a2b2 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 2]));
1139 const a3b2 = montReduce(@as(i32, a.cs[i + 3]) * @as(i32, b.cs[i + 2]));
1140 const a2b3 = montReduce(@as(i32, a.cs[i + 2]) * @as(i32, b.cs[i + 3]));
1141
1142 p.cs[i + 2] = a2b2 - montReduce(a3b3 * z);
1143 p.cs[i + 3] = a2b3 + a3b2;
1144 }
1145
1146 return p;
1147 }
1148
1149 // Sample p from a centered binomial distribution with n=2η and p=½ - viz:
1150 // coefficients are in {-η, …, η} with probabilities
1151 //
1152 // {ncr(0, 2η)/2^2η, ncr(1, 2η)/2^2η, …, ncr(2η,2η)/2^2η}
1153 fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Poly {
1154 var h = sha3.Shake256.init(.{});
1155 const suffix: [1]u8 = .{nonce};
1156 h.update(seed);
1157 h.update(&suffix);
1158
1159 // The distribution at hand is exactly the same as that
1160 // of (a₁ + a₂ + … + a_η) - (b₁ + … + b_η) where a_i,b_i~U(1).
1161 // Thus we need 2η bits per coefficient.
1162 const buf_len = comptime 2 * eta * N / 8;
1163 var buf: [buf_len]u8 = undefined;
1164 h.squeeze(&buf);
1165
1166 // buf is interpreted as a₁…a_ηb₁…b_ηa₁…a_ηb₁…b_η…. We process
1167 // multiple coefficients in one batch.
1168
1169 const T = switch (builtin.target.cpu.arch) {
1170 .x86_64, .x86 => u32, // Generates better code on Intel CPUs
1171 else => u64, // u128 might be faster on some other CPUs.
1172 };
1173
1174 comptime var batch_count: usize = undefined;
1175 comptime var batch_bytes: usize = undefined;
1176 comptime var mask: T = 0;
1177 comptime {
1178 batch_count = @bitSizeOf(T) / @as(usize, 2 * eta);
1179 while (@rem(N, batch_count) != 0 and batch_count > 0) : (batch_count -= 1) {}
1180 assert(batch_count > 0);
1181 assert(@rem(2 * eta * batch_count, 8) == 0);
1182 batch_bytes = 2 * eta * batch_count / 8;
1183
1184 for (0..2 * eta * batch_count) |_| {
1185 mask <<= eta;
1186 mask |= 1;
1187 }
1188 }
1189
1190 var ret: Poly = undefined;
1191 for (0..comptime N / batch_count) |i| {
1192 // Read coefficients into t. In the case of η=3,
1193 // we have t = a₁ + 2a₂ + 4a₃ + 8b₁ + 16b₂ + …
1194 var t: T = 0;
1195 inline for (0..batch_bytes) |j| {
1196 t |= @as(T, buf[batch_bytes * i + j]) << (8 * j);
1197 }
1198
1199 // Accumulate `a's and `b's together by masking them out, shifting
1200 // and adding. For η=3, we have d = a₁ + a₂ + a₃ + 8(b₁ + b₂ + b₃) + …
1201 var d: T = 0;
1202 inline for (0..eta) |j| {
1203 d += (t >> j) & mask;
1204 }
1205
1206 // Extract each a and b separately and set coefficient in polynomial.
1207 inline for (0..batch_count) |j| {
1208 const mask2 = comptime (1 << eta) - 1;
1209 const a = @as(i16, @intCast((d >> (comptime (2 * j * eta))) & mask2));
1210 const b = @as(i16, @intCast((d >> (comptime ((2 * j + 1) * eta))) & mask2));
1211 ret.cs[batch_count * i + j] = a - b;
1212 }
1213 }
1214
1215 return ret;
1216 }
1217
1218 fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
1219 const domain_sep: [2]u8 = .{ x, y };
1220 return sampleUniformRejection(
1221 Poly,
1222 Q,
1223 12,
1224 N,
1225 &seed,
1226 &domain_sep,
1227 );
1228 }
1229
1230 // Packs p.
1231 //
1232 // Assumes p is normalized (and not just Barrett reduced).
1233 fn toBytes(p: Poly) [encoded_length]u8 {
1234 var ret: [encoded_length]u8 = undefined;
1235 for (0..comptime N / 2) |i| {
1236 const t0 = @as(u16, @intCast(p.cs[2 * i]));
1237 const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
1238 ret[3 * i] = @as(u8, @truncate(t0));
1239 ret[3 * i + 1] = @as(u8, @truncate((t0 >> 8) | (t1 << 4)));
1240 ret[3 * i + 2] = @as(u8, @truncate(t1 >> 4));
1241 }
1242 return ret;
1243 }
1244
1245 // Unpacks a Poly from buf.
1246 //
1247 // p will not be normalized; instead 0 ≤ p[i] < 4096.
1248 fn fromBytes(buf: *const [encoded_length]u8) Poly {
1249 var ret: Poly = undefined;
1250 for (0..comptime N / 2) |i| {
1251 const b0 = @as(i16, buf[3 * i]);
1252 const b1 = @as(i16, buf[3 * i + 1]);
1253 const b2 = @as(i16, buf[3 * i + 2]);
1254 ret.cs[2 * i] = b0 | ((b1 & 0xf) << 8);
1255 ret.cs[2 * i + 1] = (b1 >> 4) | b2 << 4;
1256 }
1257 return ret;
1258 }
1259};
1260
1261// A vector of k polynomials.
1262fn PolyVec(comptime k: u8) type {
1263 return struct {
1264 ps: [k]Poly,
1265
1266 const Self = @This();
1267 const encoded_length = k * Poly.encoded_length;
1268
1269 fn compressedSize(comptime d: u8) usize {
1270 return Poly.compressedSize(d) * k;
1271 }
1272
1273 /// Apply unary operation to each polynomial
1274 fn map(v: Self, comptime op: fn (Poly) Poly) Self {
1275 var ret: Self = undefined;
1276 inline for (0..k) |i| {
1277 ret.ps[i] = op(v.ps[i]);
1278 }
1279 return ret;
1280 }
1281
1282 /// Apply binary operation pairwise
1283 fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
1284 var ret: Self = undefined;
1285 inline for (0..k) |i| {
1286 ret.ps[i] = op(a.ps[i], b.ps[i]);
1287 }
1288 return ret;
1289 }
1290
1291 fn ntt(v: Self) Self {
1292 return map(v, Poly.ntt);
1293 }
1294
1295 fn invNTT(v: Self) Self {
1296 return map(v, Poly.invNTT);
1297 }
1298
1299 fn normalize(v: Self) Self {
1300 return map(v, Poly.normalize);
1301 }
1302
1303 fn barrettReduce(v: Self) Self {
1304 return map(v, Poly.barrettReduce);
1305 }
1306
1307 fn add(a: Self, b: Self) Self {
1308 return mapBinary(a, b, Poly.add);
1309 }
1310
1311 fn sub(a: Self, b: Self) Self {
1312 return mapBinary(a, b, Poly.sub);
1313 }
1314
1315 // Samples v[i] from centered binomial distribution with the given η,
1316 // seed and nonce+i.
1317 fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
1318 var ret: Self = undefined;
1319 for (0..k) |i| {
1320 ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
1321 }
1322 return ret;
1323 }
1324
        // Returns the inner product of a and b using "pointwise" multiplication.
        //
        // See mulHat() and ntt() for a description of the multiplication.
        // Assumes a and b are in Montgomery form. The result will be in Montgomery
        // form, and its coefficients will be bounded in absolute value by 2kq.
        // If a and b are not in Montgomery form, then the action is the same
        // as "pointwise" multiplication followed by multiplying by R⁻¹, the inverse
        // of the Montgomery factor.
1333 fn dotHat(a: Self, b: Self) Poly {
1334 var ret: Poly = Poly.zero;
1335 for (0..k) |i| {
1336 ret = ret.add(a.ps[i].mulHat(b.ps[i]));
1337 }
1338 return ret;
1339 }
1340
1341 fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
1342 const cs = comptime Poly.compressedSize(d);
1343 var ret: [compressedSize(d)]u8 = undefined;
1344 inline for (0..k) |i| {
1345 ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
1346 }
1347 return ret;
1348 }
1349
1350 fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
1351 const cs = comptime Poly.compressedSize(d);
1352 var ret: Self = undefined;
1353 inline for (0..k) |i| {
1354 ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
1355 }
1356 return ret;
1357 }
1358
        /// Serializes the vector into a byte array.
1360 fn toBytes(v: Self) [encoded_length]u8 {
1361 var ret: [encoded_length]u8 = undefined;
1362 inline for (0..k) |i| {
1363 ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
1364 }
1365 return ret;
1366 }
1367
        /// Deserializes the vector from a byte array.
1369 fn fromBytes(buf: *const [encoded_length]u8) Self {
1370 var ret: Self = undefined;
1371 inline for (0..k) |i| {
1372 ret.ps[i] = Poly.fromBytes(
1373 buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
1374 );
1375 }
1376 return ret;
1377 }
1378 };
1379}
1380
1381// A matrix of k vectors
1382fn Mat(comptime k: u8) type {
1383 return struct {
1384 const Self = @This();
1385 rows: [k]PolyVec(k),
1386
1387 fn uniform(seed: [32]u8, comptime transposed: bool) Self {
1388 var ret: Self = undefined;
1389 var i: u8 = 0;
1390 while (i < k) : (i += 1) {
1391 var j: u8 = 0;
1392 while (j < k) : (j += 1) {
1393 ret.rows[i].ps[j] = Poly.uniform(
1394 seed,
1395 if (transposed) i else j,
1396 if (transposed) j else i,
1397 );
1398 }
1399 }
1400 return ret;
1401 }
1402
1403 // Returns transpose of A
1404 fn transpose(m: Self) Self {
1405 var ret: Self = undefined;
1406 for (0..k) |i| {
1407 for (0..k) |j| {
1408 ret.rows[i].ps[j] = m.rows[j].ps[i];
1409 }
1410 }
1411 return ret;
1412 }
1413 };
1414}
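// A consistency check of the transposed sampling above: sampling A "transposed"
// must give the same matrix as sampling A normally and then transposing it,
// since both expand the same seed with swapped domain-separation indices. This
// relies only on Poly.uniform being deterministic in (seed, x, y).
test "Mat.uniform transposed consistency" {
    var seed: [32]u8 = undefined;
    for (&seed, 0..) |*s, i| {
        s.* = @as(u8, @intCast(i));
    }
    const a = Mat(2).uniform(seed, false);
    const at = Mat(2).uniform(seed, true);
    try testing.expectEqual(a.transpose(), at);
}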
1415
// Returns 1 if a ≠ b and 0 otherwise, in constant time.
1417fn ctneq(comptime len: usize, a: [len]u8, b: [len]u8) u1 {
1418 return 1 - @intFromBool(crypto.timing_safe.eql([len]u8, a, b));
1419}
1420
// Copies src into dst if b = 1; leaves dst unchanged if b = 0 (in constant time).
1422fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
1423 const mask = @as(u8, 0) -% b;
1424 for (0..len) |i| {
1425 dst[i] ^= mask & (dst[i] ^ src[i]);
1426 }
1427}
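// Basic sanity checks for the constant-time helpers above; this only checks
// their functional contract, not their timing behaviour.
test "ctneq and cmov" {
    const x = [4]u8{ 1, 2, 3, 4 };
    const y = [4]u8{ 1, 2, 3, 5 };
    try testing.expectEqual(@as(u1, 0), ctneq(4, x, x));
    try testing.expectEqual(@as(u1, 1), ctneq(4, x, y));

    var dst = x;
    cmov(4, &dst, y, 0); // bit clear: dst is left untouched
    try testing.expectEqual(x, dst);
    cmov(4, &dst, y, 1); // bit set: dst becomes the source
    try testing.expectEqual(y, dst);
}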
1428
1429// Test helper: generates a random polynomial with each coefficient |x| ≤ q
1430fn randPolyAbsLeqQ(rnd: anytype) Poly {
1431 var ret: Poly = undefined;
1432 for (0..N) |i| {
1433 ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
1434 }
1435 return ret;
1436}
1437
1438// Test helper: generates a random normalized polynomial
1439fn randPolyNormalized(rnd: anytype) Poly {
1440 var ret: Poly = undefined;
1441 for (0..N) |i| {
1442 ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
1443 }
1444 return ret;
1445}
1446
1447test "MulHat" {
1448 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1449
1450 var rnd = RndGen.init(0);
1451
1452 for (0..100) |_| {
1453 const a = randPolyAbsLeqQ(&rnd);
1454 const b = randPolyAbsLeqQ(&rnd);
1455
1456 const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
1457 var p: Poly = undefined;
1458
1459 @memset(&p.cs, 0);
1460
1461 for (0..N) |i| {
1462 for (0..N) |j| {
1463 var v = montReduce(@as(i32, a.cs[i]) * @as(i32, b.cs[j]));
1464 var k = i + j;
1465 if (k >= N) {
1466 // Recall Xᴺ = -1.
1467 k -= N;
1468 v = -v;
1469 }
1470 p.cs[k] = feBarrettReduce(v + p.cs[k]);
1471 }
1472 }
1473
1474 p = p.toMont().normalize();
1475
1476 try testing.expectEqual(p, p2);
1477 }
1478}
1479
1480test "NTT" {
1481 var rnd = RndGen.init(0);
1482
1483 for (0..1000) |_| {
1484 var p = randPolyAbsLeqQ(&rnd);
1485 const q = p.toMont().normalize();
1486 p = p.ntt();
1487
1488 for (0..N) |i| {
1489 try testing.expect(p.cs[i] <= 7 * Q and -7 * Q <= p.cs[i]);
1490 }
1491
1492 p = p.normalize().invNTT();
1493 for (0..N) |i| {
1494 try testing.expect(p.cs[i] <= Q and -Q <= p.cs[i]);
1495 }
1496
1497 p = p.normalize();
1498
1499 try testing.expectEqual(p, q);
1500 }
1501}
1502
1503test "Compression" {
1504 var rnd = RndGen.init(0);
1505 inline for (.{ 1, 4, 5, 10, 11 }) |d| {
1506 for (0..1000) |_| {
1507 const p = randPolyNormalized(&rnd);
1508 const pp = p.compress(d);
1509 const pq = Poly.decompress(d, &pp).compress(d);
1510 try testing.expectEqual(pp, pq);
1511 }
1512 }
1513}
1514
1515test "noise" {
1516 var seed: [32]u8 = undefined;
1517 for (&seed, 0..) |*s, i| {
1518 s.* = @as(u8, @intCast(i));
1519 }
1520 try testing.expectEqual(Poly.noise(3, 37, &seed).cs, .{
1521 0, 0, 1, -1, 0, 2, 0, -1, -1, 3, 0, 1, -2, -2, 0, 1, -2,
1522 1, 0, -2, 3, 0, 0, 0, 1, 3, 1, 1, 2, 1, -1, -1, -1, 0,
1523 1, 0, 1, 0, 2, 0, 1, -2, 0, -1, -1, -2, 1, -1, -1, 2, -1,
1524 1, 1, 2, -3, -1, -1, 0, 0, 0, 0, 1, -1, -2, -2, 0, -2, 0,
1525 0, 0, 1, 0, -1, -1, 1, -2, 2, 0, 0, 2, -2, 0, 1, 0, 1,
1526 1, 1, 0, 1, -2, -1, -2, -1, 1, 0, 0, 0, 0, 0, 1, 0, -1,
1527 -1, 0, -1, 1, 0, 1, 0, -1, -1, 0, -2, 2, 0, -2, 1, -1, 0,
1528 1, -1, -1, 2, 1, 0, 0, -2, -1, 2, 0, 0, 0, -1, -1, 3, 1,
1529 0, 1, 0, 1, 0, 2, 1, 0, 0, 1, 0, 1, 0, 0, -1, -1, -1,
1530 0, 1, 3, 1, 0, 1, 0, 1, -1, -1, -1, -1, 0, 0, -2, -1, -1,
1531 2, 0, 1, 0, 1, 0, 2, -2, 0, 1, 1, -3, -1, -2, -1, 0, 1,
1532 0, 1, -2, 2, 2, 1, 1, 0, -1, 0, -1, -1, 1, 0, -1, 2, 1,
1533 -1, 1, 2, -2, 1, 2, 0, 1, 2, 1, 0, 0, 2, 1, 2, 1, 0,
1534 2, 1, 0, 0, -1, -1, 1, -1, 0, 1, -1, 2, 2, 0, 0, -1, 1,
1535 1, 1, 1, 0, 0, -2, 0, -1, 1, 2, 0, 0, 1, 1, -1, 1, 0,
1536 1,
1537 });
1538 try testing.expectEqual(Poly.noise(2, 37, &seed).cs, .{
1539 1, 0, 1, -1, -1, -2, -1, -1, 2, 0, -1, 0, 0, -1,
1540 1, 1, -1, 1, 0, 2, -2, 0, 1, 2, 0, 0, -1, 1,
1541 0, -1, 1, -1, 1, 2, 1, 1, 0, -1, 1, -1, -2, -1,
1542 1, -1, -1, -1, 2, -1, -1, 0, 0, 1, 1, -1, 1, 1,
1543 1, 1, -1, -2, 0, 1, 0, 0, 2, 1, -1, 2, 0, 0,
1544 1, 1, 0, -1, 0, 0, -1, -1, 2, 0, 1, -1, 2, -1,
1545 -1, -1, -1, 0, -2, 0, 2, 1, 0, 0, 0, -1, 0, 0,
1546 0, -1, -1, 0, -1, -1, 0, -1, 0, 0, -2, 1, 1, 0,
1547 1, 0, 1, 0, 1, 1, -1, 2, 0, 1, -1, 1, 2, 0,
1548 0, 0, 0, -1, -1, -1, 0, 1, 0, -1, 2, 0, 0, 1,
1549 1, 1, 0, 1, -1, 1, 2, 1, 0, 2, -1, 1, -1, -2,
1550 -1, -2, -1, 1, 0, -2, -2, -1, 1, 0, 0, 0, 0, 1,
1551 0, 0, 0, 2, 2, 0, 1, 0, -1, -1, 0, 2, 0, 0,
1552 -2, 1, 0, 2, 1, -1, -2, 0, 0, -1, 1, 1, 0, 0,
1553 2, 0, 1, 1, -2, 1, -2, 1, 1, 0, 2, 0, -1, 0,
1554 -1, 0, 1, 2, 0, 1, 0, -2, 1, -2, -2, 1, -1, 0,
1555 -1, 1, 1, 0, 0, 0, 1, 0, -1, 1, 1, 0, 0, 0,
1556 0, 1, 0, 1, -1, 0, 1, -1, -1, 2, 0, 0, 1, -1,
1557 0, 1, -1, 0,
1558 });
1559}
1560
1561test "uniform sampling" {
1562 var seed: [32]u8 = undefined;
1563 for (&seed, 0..) |*s, i| {
1564 s.* = @as(u8, @intCast(i));
1565 }
1566 try testing.expectEqual(Poly.uniform(seed, 1, 0).cs, .{
1567 797, 993, 161, 6, 2608, 2385, 2096, 2661, 1676, 247, 2440,
1568 342, 634, 194, 1570, 2848, 986, 684, 3148, 3208, 2018, 351,
1569 2288, 612, 1394, 170, 1521, 3119, 58, 596, 2093, 1549, 409,
1570 2156, 1934, 1730, 1324, 388, 446, 418, 1719, 2202, 1812, 98,
1571 1019, 2369, 214, 2699, 28, 1523, 2824, 273, 402, 2899, 246,
1572 210, 1288, 863, 2708, 177, 3076, 349, 44, 949, 854, 1371,
1573 957, 292, 2502, 1617, 1501, 254, 7, 1761, 2581, 2206, 2655,
1574 1211, 629, 1274, 2358, 816, 2766, 2115, 2985, 1006, 2433, 856,
1575 2596, 3192, 1, 1378, 2345, 707, 1891, 1669, 536, 1221, 710,
1576 2511, 120, 1176, 322, 1897, 2309, 595, 2950, 1171, 801, 1848,
1577 695, 2912, 1396, 1931, 1775, 2904, 893, 2507, 1810, 2873, 253,
1578 1529, 1047, 2615, 1687, 831, 1414, 965, 3169, 1887, 753, 3246,
1579 1937, 115, 2953, 586, 545, 1621, 1667, 3187, 1654, 1988, 1857,
1580 512, 1239, 1219, 898, 3106, 391, 1331, 2228, 3169, 586, 2412,
1581 845, 768, 156, 662, 478, 1693, 2632, 573, 2434, 1671, 173,
1582 969, 364, 1663, 2701, 2169, 813, 1000, 1471, 720, 2431, 2530,
1583 3161, 733, 1691, 527, 2634, 335, 26, 2377, 1707, 767, 3020,
1584 950, 502, 426, 1138, 3208, 2607, 2389, 44, 1358, 1392, 2334,
1585 875, 2097, 173, 1697, 2578, 942, 1817, 974, 1165, 2853, 1958,
1586 2973, 3282, 271, 1236, 1677, 2230, 673, 1554, 96, 242, 1729,
1587 2518, 1884, 2272, 71, 1382, 924, 1807, 1610, 456, 1148, 2479,
1588 2152, 238, 2208, 2329, 713, 1175, 1196, 757, 1078, 3190, 3169,
1589 708, 3117, 154, 1751, 3225, 1364, 154, 23, 2842, 1105, 1419,
1590 79, 5, 2013,
1591 });
1592}
1593
1594test "Polynomial packing" {
1595 var rnd = RndGen.init(0);
1596
1597 for (0..1000) |_| {
1598 const p = randPolyNormalized(&rnd);
1599 try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
1600 }
1601}
1602
1603test "Test inner PKE" {
1604 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1605
1606 var seed: [32]u8 = undefined;
1607 var pt: [32]u8 = undefined;
1608 for (&seed, &pt, 0..) |*s, *p, i| {
1609 s.* = @as(u8, @intCast(i));
1610 p.* = @as(u8, @intCast(i + 32));
1611 }
1612 inline for (modes) |mode| {
1613 for (0..10) |i| {
1614 var pk: mode.InnerPk = undefined;
1615 var sk: mode.InnerSk = undefined;
1616 seed[0] = @as(u8, @intCast(i));
1617 mode.innerKeyFromSeed(seed, &pk, &sk);
1618 for (0..10) |j| {
1619 seed[1] = @as(u8, @intCast(j));
1620 try testing.expectEqual(sk.decrypt(&pk.encrypt(&pt, &seed)), pt);
1621 }
1622 }
1623 }
1624}
1625
1626test "Test happy flow" {
1627 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1628
1629 var seed: [64]u8 = undefined;
1630 for (&seed, 0..) |*s, i| {
1631 s.* = @as(u8, @intCast(i));
1632 }
1633 inline for (modes) |mode| {
1634 for (0..10) |i| {
1635 seed[0] = @as(u8, @intCast(i));
1636 const kp = try mode.KeyPair.generateDeterministic(seed);
1637 const sk = try mode.SecretKey.fromBytes(&kp.secret_key.toBytes());
1638 try testing.expectEqual(sk, kp.secret_key);
1639 const pk = try mode.PublicKey.fromBytes(&kp.public_key.toBytes());
1640 try testing.expectEqual(pk, kp.public_key);
1641 for (0..10) |j| {
1642 seed[1] = @as(u8, @intCast(j));
1643 const e = pk.encaps(seed[0..32].*);
1644 try testing.expectEqual(e.shared_secret, try sk.decaps(&e.ciphertext));
1645 }
1646 }
1647 }
1648}
1649
1650// Code to test NIST Known Answer Tests (KAT), see PQCgenKAT.c.
1651
1652test "NIST KAT test d00.Kyber512" {
1653 if (comptime builtin.cpu.has(.loongarch, .lsx)) return error.SkipZigTest;
1654 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1655
1656 try testNistKat(d00.Kyber512, "e9c2bd37133fcb40772f81559f14b1f58dccd1c816701be9ba6214d43baf4547");
1657}
1658
1659test "NIST KAT test d00.Kyber1024" {
1660 if (comptime builtin.cpu.has(.loongarch, .lsx)) return error.SkipZigTest;
1661 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1662
1663 try testNistKat(d00.Kyber1024, "89248f2f33f7f4f7051729111f3049c409a933ec904aedadf035f30fa5646cd5");
1664}
1665
1666test "NIST KAT test d00.Kyber768" {
1667 if (comptime builtin.cpu.has(.loongarch, .lsx)) return error.SkipZigTest;
1668 if (comptime builtin.cpu.has(.s390x, .vector)) return error.SkipZigTest;
1669
1670 try testNistKat(d00.Kyber768, "a1e122cad3c24bc51622e4c242d8b8acbcd3f618fee4220400605ca8f9ea02c2");
1671}
1672
1673fn testNistKat(mode: type, hash: []const u8) !void {
1674 var seed: [48]u8 = undefined;
1675 for (&seed, 0..) |*s, i| {
1676 s.* = @as(u8, @intCast(i));
1677 }
1678 var fw: std.Io.Writer.Hashing(crypto.hash.sha2.Sha256) = .init(&.{});
1679 var g = NistDRBG.init(seed);
1680 try fw.writer.print("# {s}\n\n", .{mode.name});
1681 for (0..100) |i| {
1682 g.fill(&seed);
1683 try fw.writer.print("count = {}\n", .{i});
1684 try fw.writer.print("seed = {X}\n", .{&seed});
1685 var g2 = NistDRBG.init(seed);
1686
1687 // This is not equivalent to g2.fill(kseed[:]). As the reference
1688 // implementation calls randombytes twice generating the keypair,
1689 // we have to do that as well.
1690 var kseed: [64]u8 = undefined;
1691 var eseed: [32]u8 = undefined;
1692 g2.fill(kseed[0..32]);
1693 g2.fill(kseed[32..64]);
1694 g2.fill(&eseed);
1695 const kp = try mode.KeyPair.generateDeterministic(kseed);
1696 const e = kp.public_key.encaps(eseed);
1697 const ss2 = try kp.secret_key.decaps(&e.ciphertext);
1698 try testing.expectEqual(ss2, e.shared_secret);
1699 try fw.writer.print("pk = {X}\n", .{&kp.public_key.toBytes()});
1700 try fw.writer.print("sk = {X}\n", .{&kp.secret_key.toBytes()});
1701 try fw.writer.print("ct = {X}\n", .{&e.ciphertext});
1702 try fw.writer.print("ss = {X}\n\n", .{&e.shared_secret});
1703 }
1704
1705 var out: [32]u8 = undefined;
1706 fw.hasher.final(&out);
1707 var outHex: [64]u8 = undefined;
1708 _ = try std.fmt.bufPrint(&outHex, "{x}", .{&out});
1709 try testing.expectEqualStrings(&outHex, hash);
1710}
1711
1712const NistDRBG = struct {
1713 key: [32]u8,
1714 v: [16]u8,
1715
1716 fn incV(g: *NistDRBG) void {
        // Increment v as a 16-byte big-endian counter, walking the indices
        // 15 down to 0 without underflowing the unsigned loop index.
        var j: usize = 16;
        while (j > 0) {
            j -= 1;
            if (g.v[j] == 255) {
                g.v[j] = 0;
            } else {
                g.v[j] += 1;
                break;
            }
        }
1726 }
1727
1728 // AES256_CTR_DRBG_Update(pd, &g.key, &g.v).
1729 fn update(g: *NistDRBG, pd: ?[48]u8) void {
1730 var buf: [48]u8 = undefined;
1731 const ctx = crypto.core.aes.Aes256.initEnc(g.key);
1732 var i: usize = 0;
1733 while (i < 3) : (i += 1) {
1734 g.incV();
1735 var block: [16]u8 = undefined;
1736 ctx.encrypt(&block, &g.v);
1737 buf[i * 16 ..][0..16].* = block;
1738 }
1739 if (pd) |p| {
1740 for (&buf, p) |*b, x| {
1741 b.* ^= x;
1742 }
1743 }
1744 g.key = buf[0..32].*;
1745 g.v = buf[32..48].*;
1746 }
1747
1748 // randombytes.
1749 fn fill(g: *NistDRBG, out: []u8) void {
1750 var block: [16]u8 = undefined;
1751 var dst = out;
1752
1753 const ctx = crypto.core.aes.Aes256.initEnc(g.key);
1754 while (dst.len > 0) {
1755 g.incV();
1756 ctx.encrypt(&block, &g.v);
1757 if (dst.len < 16) {
1758 @memcpy(dst, block[0..dst.len]);
1759 break;
1760 }
1761 dst[0..block.len].* = block;
1762 dst = dst[16..dst.len];
1763 }
1764 g.update(null);
1765 }
1766
1767 fn init(seed: [48]u8) NistDRBG {
1768 var ret: NistDRBG = .{ .key = .{0} ** 32, .v = .{0} ** 16 };
1769 ret.update(seed);
1770 return ret;
1771 }
1772};
1773
/// Extended Euclidean Algorithm
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    var a = a_;
    var b = b_;
    var x0: T = 1;
    var x1: T = 0;
    var y0: T = 0;
    var y1: T = 1;

    while (b != 0) {
        const q = @divTrunc(a, b);
        const temp_a = a;
        a = b;
        b = temp_a - q * b;

        const temp_x = x0;
        x0 = x1;
        x1 = temp_x - q * x1;

        const temp_y = y0;
        y0 = y1;
        y1 = temp_y - q * y1;
    }

    return .{ .gcd = a, .x = x0, .y = y0 };
}
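
// A small sanity check (illustrative, not a reference test vector): the
// returned Bezout coefficients must satisfy a*x + b*y == gcd(a, b).
test "extendedEuclidean Bezout identity" {
    const r = comptime extendedEuclidean(i32, 240, 46);
    try testing.expectEqual(@as(i32, 2), r.gcd);
    try testing.expectEqual(@as(i32, 2), 240 * r.x + 46 * r.y);
}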

/// Modular inversion: computes a^(-1) mod p
/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p).
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // Use a signed type for EEA computation
    const type_info = @typeInfo(T);
    const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
        std.meta.Int(.signed, type_info.int.bits)
    else
        T;

    const a_signed = @as(SignedT, @intCast(a));
    const p_signed = @as(SignedT, @intCast(p));

    const r = extendedEuclidean(SignedT, a_signed, p_signed);
    assert(r.gcd == 1);

    // Normalize result to [0, p)
    var result = r.x;
    while (result < 0) {
        result += p_signed;
    }

    return @intCast(result);
}
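
// Illustrative check of the defining property a * a^(-1) == 1 (mod p), here
// with the ML-KEM modulus q = 3329 and an arbitrary unit a = 17.
test "modularInverse property" {
    const inv = comptime modularInverse(u16, 17, 3329);
    try testing.expectEqual(@as(u32, 1), (@as(u32, 17) * @as(u32, inv)) % 3329);
}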

/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const type_info = @typeInfo(T);
    const bits = type_info.int.bits;
    const WideT = std.meta.Int(.unsigned, bits * 2);

    var ret: T = 1;
    var base: T = a;
    var exp = s;

    while (exp > 0) {
        if (exp & 1 == 1) {
            ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
        }
        base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
        exp >>= 1;
    }

    return ret;
}
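
// Illustrative cross-check: since the ML-KEM modulus q = 3329 is prime,
// Fermat's little theorem gives a^(q-2) == a^(-1) (mod q), so modularPow and
// modularInverse must agree.
test "modularPow matches modularInverse via Fermat" {
    const q: u16 = 3329;
    const expected = comptime modularInverse(u16, 17, q);
    try testing.expectEqual(expected, modularPow(u16, 17, q - 2, q));
}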

/// Creates an all-ones or all-zeros mask from a single bit value.
/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
fn bitMask(comptime T: type, bit: T) T {
    const type_info = @typeInfo(T);
    if (type_info != .int or type_info.int.signedness != .unsigned) {
        @compileError("bitMask requires an unsigned integer type");
    }
    return -%bit;
}

/// Creates a mask from the most significant (sign) bit of an integer.
/// Returns all 1s (0xFF...FF) if the MSB is set (x < 0 for signed types), all 0s otherwise.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const type_info = @typeInfo(T);
    if (type_info != .int) {
        @compileError("signMask requires an integer type");
    }

    const bits = type_info.int.bits;
    const SignedT = std.meta.Int(.signed, bits);

    // Convert to signed if needed, arithmetic right shift to propagate sign bit
    const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
    const shifted = x_signed >> (bits - 1);
    return @bitCast(shifted);
}

test "bitMask and signMask helpers" {
    try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0));
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1));
    try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0));
    try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1));
    try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0));
    try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1));

    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1));
    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1));
    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100));

    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
    try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
}

/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
/// This is a generic implementation parameterized by the modulus q, its inverse
/// qInv ≡ q^(-1) (mod R), the Montgomery constant R = 2^r_bits, and the input and
/// output integer types, which bound the result.
///
/// For ML-DSA: R = 2^32, returns y < 2q.
/// For ML-KEM: R = 2^16, returns y in the range (-q, q).
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);

    const yR = x -% @as(InT, m) * @as(InT, q);
    const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).int.bits), @bitCast(yR)) >> r_bits;
    return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).int.bits), @truncate(y_shifted)));
}
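
// Illustrative check with ML-KEM-style parameters (assumed here: q = 3329,
// qInv = 62209 = q^(-1) mod 2^16, R = 2^16): the reduced value y must satisfy
// y*R ≡ x (mod q). The input x is arbitrary.
test "montgomeryReduce congruence" {
    const x: i32 = 1234567;
    const y = montgomeryReduce(i32, i16, 3329, 62209, 16, x);
    try testing.expectEqual(
        @mod(@as(i64, x), 3329),
        @mod(@as(i64, y) * (1 << 16), 3329),
    );
}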

/// Uniform sampling using SHAKE-128 with rejection sampling.
/// Samples polynomial coefficients uniformly from [0, q) using rejection sampling.
///
/// Parameters:
/// - PolyType: The polynomial type to return
/// - q: Modulus
/// - bits_per_coef: Number of bits per coefficient (12 or 23)
/// - n: Number of coefficients
/// - seed: Random seed
/// - domain_sep: Domain separation bytes (appended to seed)
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    var h = sha3.Shake128.init(.{});
    h.update(seed);
    h.update(domain_sep);

    const buf_len = sha3.Shake128.block_length; // 168 bytes
    var buf: [buf_len]u8 = undefined;

    var ret: PolyType = undefined;
    var coef_idx: usize = 0;

    if (bits_per_coef == 12) {
        // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each)
        outer: while (true) {
            h.squeeze(&buf);

            var j: usize = 0;
            while (j < buf_len) : (j += 3) {
                const b0 = @as(u16, buf[j]);
                const b1 = @as(u16, buf[j + 1]);
                const b2 = @as(u16, buf[j + 2]);

                const ts: [2]u16 = .{
                    b0 | ((b1 & 0xf) << 8),
                    (b1 >> 4) | (b2 << 4),
                };

                inline for (ts) |t| {
                    if (t < q) {
                        ret.cs[coef_idx] = @intCast(t);
                        coef_idx += 1;
                        if (coef_idx == n) break :outer;
                    }
                }
            }
        }
    } else if (bits_per_coef == 23) {
        // ML-DSA path: 1 coefficient per 3 bytes (23 bits)
        while (coef_idx < n) {
            h.squeeze(&buf);

            var j: usize = 0;
            while (j < buf_len and coef_idx < n) : (j += 3) {
                const t = (@as(u32, buf[j]) |
                    (@as(u32, buf[j + 1]) << 8) |
                    (@as(u32, buf[j + 2]) << 16)) & 0x7fffff;

                if (t < q) {
                    ret.cs[coef_idx] = @intCast(t);
                    coef_idx += 1;
                }
            }
        }
    } else {
        @compileError("bits_per_coef must be 12 or 23");
    }

    return ret;
}
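
// Illustrative usage sketch with a hypothetical 8-coefficient polynomial type;
// the real callers pass the ML-KEM/ML-DSA polynomial types. Every sampled
// coefficient must land in [0, q); the seed and domain separator are arbitrary.
test "sampleUniformRejection coefficient range" {
    const TinyPoly = struct { cs: [8]u16 };
    const seed = [_]u8{0x42} ** 32;
    const p = sampleUniformRejection(TinyPoly, 3329, 12, 8, &seed, &[_]u8{ 0, 0 });
    for (p.cs) |c| try testing.expect(c < 3329);
}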