1//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in NIST FIPS 204.
2//!
3//! ML-DSA is a post-quantum secure digital signature scheme based on the hardness
4//! of the Module Learning With Errors (MLWE) and Module Short Integer Solution (MSIS)
5//! problems over module lattices.
6//!
7//! We provide three parameter sets:
8//!
9//! - ML-DSA-44: NIST security category 2 (128-bit security)
10//! - ML-DSA-65: NIST security category 3 (192-bit security)
11//! - ML-DSA-87: NIST security category 5 (256-bit security)
12
13const std = @import("std");
14const builtin = @import("builtin");
15const testing = std.testing;
16const assert = std.debug.assert;
17const crypto = std.crypto;
18const errors = std.crypto.errors;
19const math = std.math;
20const mem = std.mem;
21const sha3 = crypto.hash.sha3;
22
23const ContextTooLongError = errors.ContextTooLongError;
24const EncodingError = errors.EncodingError;
25const SignatureVerificationError = errors.SignatureVerificationError;
26
27/// ML-DSA-44 (Module-Lattice-Based Digital Signature Algorithm, 44 parameter set)
28/// as specified in NIST FIPS 204.
29///
/// This is a post-quantum signature scheme providing NIST security category 2,
/// which is roughly as hard to break as finding a collision in SHA-256.
32///
33/// Key sizes:
34///
35/// - Public key: 1312 bytes
36/// - Secret key: 2560 bytes
37/// - Signature: 2420 bytes
38///
39/// Example usage:
40///
41/// ```zig
42/// const kp = MLDSA44.KeyPair.generate();
43/// const msg = "Hello, post-quantum world!";
44/// const sig = try kp.sign(msg, null);
45/// try sig.verify(msg, kp.public_key);
46/// ```
47pub const MLDSA44 = MLDSAImpl(.{
48 .name = "ML-DSA-44",
49 .k = 4,
50 .l = 4,
51 .eta = 2,
52 .omega = 80,
53 .tau = 39,
54 .gamma1_bits = 17,
55 .gamma2 = 95232, // (Q-1)/88
56 .tr_size = 64,
57 .ctilde_size = 32,
58});
59
60/// ML-DSA-65 (Module-Lattice-Based Digital Signature Algorithm, 65 parameter set)
61/// as specified in NIST FIPS 204.
62///
/// This is a post-quantum signature scheme providing NIST security category 3,
/// which is roughly as hard to break as recovering an AES-192 key.
65///
66/// Key sizes:
67///
68/// - Public key: 1952 bytes
69/// - Secret key: 4032 bytes
70/// - Signature: 3309 bytes
71///
72/// This parameter set offers higher security than ML-DSA-44 at the cost of
73/// larger keys and signatures.
74pub const MLDSA65 = MLDSAImpl(.{
75 .name = "ML-DSA-65",
76 .k = 6,
77 .l = 5,
78 .eta = 4,
79 .omega = 55,
80 .tau = 49,
81 .gamma1_bits = 19,
82 .gamma2 = 261888, // (Q-1)/32
83 .tr_size = 64,
84 .ctilde_size = 48,
85});
86
87/// ML-DSA-87 (Module-Lattice-Based Digital Signature Algorithm, 87 parameter set)
88/// as specified in NIST FIPS 204.
89///
/// This is a post-quantum signature scheme providing NIST security category 5,
/// which is roughly as hard to break as recovering an AES-256 key.
92///
93/// Key sizes:
94///
95/// - Public key: 2592 bytes
96/// - Secret key: 4896 bytes
97/// - Signature: 4627 bytes
98///
99/// This parameter set offers the highest security level among the three ML-DSA
100/// variants, suitable for applications requiring maximum security assurance.
101pub const MLDSA87 = MLDSAImpl(.{
102 .name = "ML-DSA-87",
103 .k = 8,
104 .l = 7,
105 .eta = 2,
106 .omega = 75,
107 .tau = 60,
108 .gamma1_bits = 19,
109 .gamma2 = 261888, // (Q-1)/32
110 .tr_size = 64,
111 .ctilde_size = 64,
112});
113
const N: usize = 256; // Number of coefficients per polynomial (ring Z_q[X]/(X^256 + 1))
115const Q: u32 = 8380417; // Modulus: 2^23 - 2^13 + 1
116const Q_BITS: u32 = 23;
117const D: u32 = 13; // Dropped bits in power2Round
118
// Montgomery constant R = 2^32
120const R: u64 = 1 << 32;
121
// Q_INV = -(Q^-1) mod 2^32, used in Montgomery reduction
123const Q_INV: u32 = 4236238847;
124
125// (256)^(-1) * R^2 mod q, used in inverse NTT
126const R_OVER_256: u32 = 41978;
127
128// Primitive 512th root of unity
129const ZETA: u32 = 1753;
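
// Comptime sanity checks (illustrative additions): they tie the constants above
// to the parameter sets defined earlier.
comptime {
    assert(Q == (1 << 23) - (1 << 13) + 1);
    assert(MLDSA44.params.gamma2 == (Q - 1) / 88);
    assert(MLDSA65.params.gamma2 == (Q - 1) / 32);
    assert(MLDSA87.params.gamma2 == (Q - 1) / 32);
}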
130
131const Params = struct {
132 name: []const u8,
133
134 // Matrix dimensions
135 k: u8, // Height of matrix A
136 l: u8, // Width of matrix A
137
138 // Sampling parameter
139 eta: u8, // Bound for secret coefficients
140
141 // Hint parameters
142 omega: u16, // Maximum number of hint bits
143
144 // Challenge parameter
145 tau: u16, // Weight of challenge polynomial
146
147 // Rounding parameters
148 gamma1_bits: u8, // Bits for gamma1
149 gamma2: u32, // Parameter for decompose
150
151 // Sizes
152 tr_size: usize, // Size of tr hash
153 ctilde_size: usize, // Size of challenge hash
154};
155
156const Poly = struct {
157 cs: [N]u32,
158
159 const zero: Poly = .{ .cs = .{0} ** N };
160
161 // Add two polynomials (no normalization)
162 fn add(a: Poly, b: Poly) Poly {
163 var ret: Poly = undefined;
164 for (0..N) |i| {
165 ret.cs[i] = a.cs[i] + b.cs[i];
166 }
167 return ret;
168 }
169
170 // Subtract two polynomials (assumes b coefficients < 2q)
171 fn sub(a: Poly, b: Poly) Poly {
172 var ret: Poly = undefined;
173 for (0..N) |i| {
174 ret.cs[i] = a.cs[i] +% (@as(u32, 2 * Q) -% b.cs[i]);
175 }
176 return ret;
177 }
178
179 // Reduce each coefficient to < 2q
180 fn reduceLe2Q(p: Poly) Poly {
181 var ret = p;
182 for (0..N) |i| {
183 ret.cs[i] = le2Q(ret.cs[i]);
184 }
185 return ret;
186 }
187
188 // Normalize coefficients to [0, q)
189 fn normalize(p: Poly) Poly {
190 var ret = p;
191 for (0..N) |i| {
192 ret.cs[i] = modQ(ret.cs[i]);
193 }
194 return ret;
195 }
196
197 // Normalize assuming coefficients already < 2q
198 fn normalizeAssumingLe2Q(p: Poly) Poly {
199 var ret = p;
200 for (0..N) |i| {
201 ret.cs[i] = le2qModQ(ret.cs[i]);
202 }
203 return ret;
204 }
205
206 // Pointwise multiplication in NTT domain (Montgomery form)
207 fn mulHat(a: Poly, b: Poly) Poly {
208 var ret: Poly = undefined;
209 for (0..N) |i| {
210 ret.cs[i] = montReduceLe2Q(@as(u64, a.cs[i]) * @as(u64, b.cs[i]));
211 }
212 return ret;
213 }
214
215 // Forward NTT
216 fn ntt(p: Poly) Poly {
217 var ret = p;
218 ret.nttInPlace();
219 return ret;
220 }
221
222 // In-place forward NTT
223 fn nttInPlace(p: *Poly) void {
224 var k: usize = 0;
225 var l: usize = N / 2;
226
227 while (l > 0) : (l >>= 1) {
228 var offset: usize = 0;
229 while (offset < N - l) : (offset += 2 * l) {
230 k += 1;
231 const zeta: u64 = zetas[k];
232
233 for (offset..offset + l) |j| {
234 const t = montReduceLe2Q(zeta * @as(u64, p.cs[j + l]));
235 p.cs[j + l] = p.cs[j] +% (2 * Q -% t);
236 p.cs[j] +%= t;
237 }
238 }
239 }
240 }
241
242 // Inverse NTT
243 fn invNTT(p: Poly) Poly {
244 var ret = p;
245 ret.invNTTInPlace();
246 return ret;
247 }
248
249 // In-place inverse NTT
250 fn invNTTInPlace(p: *Poly) void {
251 var k: usize = 0;
252 var l: usize = 1;
253
254 while (l < N) : (l <<= 1) {
255 var offset: usize = 0;
256 while (offset < N - l) : (offset += 2 * l) {
257 const zeta: u64 = inv_zetas[k];
258 k += 1;
259
260 for (offset..offset + l) |j| {
261 const t = p.cs[j];
262 p.cs[j] = t +% p.cs[j + l];
263 p.cs[j + l] = montReduceLe2Q(zeta * @as(u64, t +% 256 * Q -% p.cs[j + l]));
264 }
265 }
266 }
267
268 for (0..N) |j| {
269 p.cs[j] = montReduceLe2Q(@as(u64, R_OVER_256) * @as(u64, p.cs[j]));
270 }
271 }
272
273 /// Apply Power2Round to all coefficients
274 /// Returns both t0 and t1 polynomials
275 fn power2RoundPoly(p: Poly) struct { t0: Poly, t1: Poly } {
276 var t0 = Poly.zero;
277 var t1 = Poly.zero;
278 for (0..N) |i| {
279 const result = power2Round(p.cs[i]);
280 t0.cs[i] = result.a0_plus_q;
281 t1.cs[i] = result.a1;
282 }
283 return .{ .t0 = t0, .t1 = t1 };
284 }
285
286 // Check if infinity norm exceeds bound
287 fn exceeds(p: Poly, bound: u32) bool {
288 var result: u32 = 0;
289 for (0..N) |i| {
290 const x = @as(i32, @intCast((Q - 1) / 2)) - @as(i32, @intCast(p.cs[i]));
291 const abs_x = x ^ (x >> 31);
292 const norm = @as(i32, @intCast((Q - 1) / 2)) - abs_x;
293 const exceeds_bit = @intFromBool(@as(u32, @intCast(norm)) >= bound);
294 result |= exceeds_bit;
295 }
296 return result != 0;
297 }
298};
299
300fn PolyVec(comptime len: u8) type {
301 return struct {
302 ps: [len]Poly,
303
304 const Self = @This();
305 const zero: Self = .{ .ps = .{Poly.zero} ** len };
306
307 /// Apply a unary operation to each polynomial in the vector
308 fn map(v: Self, comptime op: fn (Poly) Poly) Self {
309 var ret: Self = undefined;
310 inline for (0..len) |i| {
311 ret.ps[i] = op(v.ps[i]);
312 }
313 return ret;
314 }
315
316 /// Apply a binary operation pairwise to two vectors
317 fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
318 var ret: Self = undefined;
319 inline for (0..len) |i| {
320 ret.ps[i] = op(a.ps[i], b.ps[i]);
321 }
322 return ret;
323 }
324
325 /// Apply a binary operation between a vector and a scalar polynomial
326 fn mapBinaryPoly(v: Self, scalar: Poly, comptime op: fn (Poly, Poly) Poly) Self {
327 var ret: Self = undefined;
328 inline for (0..len) |i| {
329 ret.ps[i] = op(v.ps[i], scalar);
330 }
331 return ret;
332 }
333
334 fn add(a: Self, b: Self) Self {
335 return mapBinary(a, b, Poly.add);
336 }
337
338 fn sub(a: Self, b: Self) Self {
339 return mapBinary(a, b, Poly.sub);
340 }
341
342 fn ntt(v: Self) Self {
343 return map(v, Poly.ntt);
344 }
345
346 fn invNTT(v: Self) Self {
347 return map(v, Poly.invNTT);
348 }
349
350 fn normalize(v: Self) Self {
351 return map(v, Poly.normalize);
352 }
353
354 fn reduceLe2Q(v: Self) Self {
355 return map(v, Poly.reduceLe2Q);
356 }
357
358 fn normalizeAssumingLe2Q(v: Self) Self {
359 return map(v, Poly.normalizeAssumingLe2Q);
360 }
361
362 // Check if any polynomial in the vector exceeds the bound
363 fn exceeds(v: Self, bound: u32) bool {
364 var result = false;
365 for (0..len) |i| {
366 result = result or v.ps[i].exceeds(bound);
367 }
368 return result;
369 }
370
371 /// Apply Power2Round to each polynomial in the vector
        /// Writes t0 into `t0_out` and returns t1
373 fn power2Round(v: Self, t0_out: *Self) Self {
374 var t1: Self = undefined;
375 for (0..len) |i| {
376 const result = v.ps[i].power2RoundPoly();
377 t0_out.ps[i] = result.t0;
378 t1.ps[i] = result.t1;
379 }
380 return t1;
381 }
382
383 /// Generic packing function for vectors
384 fn packWith(
385 v: Self,
386 buf: []u8,
387 comptime poly_size: usize,
388 comptime pack_fn: fn (Poly, []u8) void,
389 ) void {
390 inline for (0..len) |i| {
391 const offset = i * poly_size;
392 pack_fn(v.ps[i], buf[offset..][0..poly_size]);
393 }
394 }
395
396 /// Generic unpacking function for vectors
397 fn unpackWith(
398 comptime poly_size: usize,
399 comptime unpack_fn: fn ([]const u8) Poly,
400 buf: []const u8,
401 ) Self {
402 var result: Self = undefined;
403 inline for (0..len) |i| {
404 const offset = i * poly_size;
405 result.ps[i] = unpack_fn(buf[offset..][0..poly_size]);
406 }
407 return result;
408 }
409
410 /// Pack T1 vector to bytes
411 fn packT1(v: Self, buf: []u8) void {
412 const poly_size = (N * (Q_BITS - D)) / 8;
413 packWith(v, buf, poly_size, polyPackT1);
414 }
415
416 /// Unpack T1 vector from bytes
417 fn unpackT1(bytes: []const u8) Self {
418 const poly_size = (N * (Q_BITS - D)) / 8;
419 return unpackWith(poly_size, polyUnpackT1, bytes);
420 }
421
422 /// Pack T0 vector to bytes
423 fn packT0(v: Self, buf: []u8) void {
424 const poly_size = (N * D) / 8;
425 packWith(v, buf, poly_size, polyPackT0);
426 }
427
428 /// Unpack T0 vector from bytes
429 fn unpackT0(buf: []const u8) Self {
430 const poly_size = (N * D) / 8;
431 return unpackWith(poly_size, polyUnpackT0, buf);
432 }
433
434 /// Pack vector with coefficients in [-eta, eta]
435 fn packLeqEta(v: Self, comptime eta: u8, buf: []u8) void {
436 const poly_size = if (eta == 2) 96 else 128;
437 const pack_fn = struct {
438 fn pack(p: Poly, b: []u8) void {
439 polyPackLeqEta(p, eta, b);
440 }
441 }.pack;
442 packWith(v, buf, poly_size, pack_fn);
443 }
444
445 /// Unpack vector with coefficients in [-eta, eta]
446 fn unpackLeqEta(comptime eta: u8, buf: []const u8) Self {
447 const poly_size = if (eta == 2) 96 else 128;
448 const unpack_fn = struct {
449 fn unpack(b: []const u8) Poly {
450 return polyUnpackLeqEta(eta, b);
451 }
452 }.unpack;
453 return unpackWith(poly_size, unpack_fn, buf);
454 }
455
456 /// Pack vector of polynomials with coefficients < gamma1
457 fn packLeGamma1(v: Self, comptime gamma1_bits: u8, buf: []u8) void {
458 const poly_size = ((gamma1_bits + 1) * N) / 8;
459 const pack_fn = struct {
460 fn pack(p: Poly, b: []u8) void {
461 polyPackLeGamma1(p, gamma1_bits, b);
462 }
463 }.pack;
464 packWith(v, buf, poly_size, pack_fn);
465 }
466
467 /// Unpack vector of polynomials with coefficients < gamma1
468 fn unpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Self {
469 const poly_size = ((gamma1_bits + 1) * N) / 8;
470 const unpack_fn = struct {
471 fn unpack(b: []const u8) Poly {
472 return polyUnpackLeGamma1(gamma1_bits, b);
473 }
474 }.unpack;
475 return unpackWith(poly_size, unpack_fn, buf);
476 }
477
478 /// Pack high bits w1 for signature verification
479 fn packW1(v: Self, comptime gamma1_bits: u8, buf: []u8) void {
480 const poly_size = (N * (Q_BITS - gamma1_bits)) / 8;
481 const pack_fn = struct {
482 fn pack(p: Poly, b: []u8) void {
483 polyPackW1(p, gamma1_bits, b);
484 }
485 }.pack;
486 packWith(v, buf, poly_size, pack_fn);
487 }
488
489 /// Decompose each polynomial in the vector into high and low bits
490 fn decomposeVec(v: Self, comptime gamma2: u32, w0_out: *Self) Self {
491 var w1: Self = undefined;
492 for (0..len) |i| {
493 for (0..N) |j| {
494 const r = decompose(v.ps[i].cs[j], gamma2);
495 w0_out.ps[i].cs[j] = r.a0_plus_q;
496 w1.ps[i].cs[j] = r.a1;
497 }
498 }
499 return w1;
500 }
501
502 /// Create hints for vector, returns hint population count
503 fn makeHintVec(w0mcs2pct0: Self, w1: Self, comptime gamma2: u32) struct { hint: Self, pop: u32 } {
504 var hint: Self = undefined;
505 var pop: u32 = 0;
506 for (0..len) |i| {
507 const result = polyMakeHint(w0mcs2pct0.ps[i], w1.ps[i], gamma2);
508 hint.ps[i] = result.hint;
509 pop += result.count;
510 }
511 return .{ .hint = hint, .pop = pop };
512 }
513
514 /// Apply hints to recover high bits
515 fn useHint(v: Self, hint: Self, comptime gamma2: u32) Self {
516 var result: Self = undefined;
517 for (0..len) |i| {
518 result.ps[i] = polyUseHint(v.ps[i], hint.ps[i], gamma2);
519 }
520 return result;
521 }
522
523 /// Multiply vector by 2^D (left shift)
524 fn mulBy2toD(v: Self) Self {
525 var result: Self = undefined;
526 for (0..len) |i| {
527 for (0..N) |j| {
528 result.ps[i].cs[j] = v.ps[i].cs[j] << D;
529 }
530 }
531 return result;
532 }
533
534 /// Sample vector with coefficients uniformly in (-gamma1, gamma1]
535 /// Wraps expandMask (FIPS 204: ExpandMask)
536 fn deriveUniformLeGamma1(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Self {
537 var result: Self = undefined;
538 for (0..len) |i| {
539 result.ps[i] = expandMask(gamma1_bits, seed, nonce + @as(u16, @intCast(i)));
540 }
541 return result;
542 }
543
544 /// Pack hints into bytes
545 /// Format: for each polynomial, find positions where hint[i]=1, encode those positions
546 fn packHint(v: Self, comptime omega: u16, buf: []u8) bool {
547 var idx: usize = 0;
548 var count: u32 = 0;
549
550 for (0..len) |i| {
551 for (0..N) |j| {
552 if (v.ps[i].cs[j] != 0) {
553 count += 1;
554 }
555 }
556 }
557
558 if (count > omega) {
559 return false;
560 }
561
562 // Hint encoding format per FIPS 204:
563 // First omega bytes: positions of set bits across all polynomials
564 // Last len bytes: boundary indices showing where each polynomial's hints end
565 for (0..len) |i| {
566 for (0..N) |j| {
567 if (v.ps[i].cs[j] != 0) {
568 buf[idx] = @intCast(j);
569 idx += 1;
570 }
571 }
572 buf[omega + i] = @intCast(idx);
573 }
574
575 while (idx < omega) : (idx += 1) {
576 buf[idx] = 0;
577 }
578
579 return true;
580 }
581
582 /// Unpack hints from bytes
583 fn unpackHint(comptime omega: u16, buf: []const u8) ?Self {
584 var result: Self = .{ .ps = .{Poly.zero} ** len };
585 var prev_sop: u8 = 0; // previous switch-over-point
586
587 for (0..len) |i| {
588 const sop = buf[omega + i]; // switch-over-point
589 if (sop < prev_sop or sop > omega) {
590 return null; // ensures switch-over-points are increasing
591 }
592
593 var j = prev_sop;
594 while (j < sop) : (j += 1) {
595 // Validation: indices must be strictly increasing within each polynomial
596 if (j > prev_sop and buf[j] <= buf[j - 1]) {
597 return null;
598 }
599 const pos = buf[j];
600 if (pos >= N) {
601 return null;
602 }
603 result.ps[i].cs[pos] = 1;
604 }
605 prev_sop = sop;
606 }
607
608 var j = prev_sop;
609 while (j < omega) : (j += 1) {
610 if (buf[j] != 0) {
611 return null;
612 }
613 }
614
615 return result;
616 }
617 };
618}
619
620// Matrix of k x l polynomials
621
622fn Mat(comptime k: u8, comptime l: u8) type {
623 return struct {
624 rows: [k]PolyVec(l),
625
626 const Self = @This();
627 const VecL = PolyVec(l);
628 const VecK = PolyVec(k);
629
630 /// Expand matrix A from seed rho using SHAKE-128
631 /// This is the ExpandA function from FIPS 204
632 fn derive(rho: *const [32]u8) Self {
633 var m: Self = undefined;
634 for (0..k) |i| {
635 if (i + 1 < k) {
636 @prefetch(&m.rows[i + 1], .{ .rw = .write, .locality = 2 });
637 }
638 for (0..l) |j| {
639 // Nonce is i*256 + j
640 const nonce: u16 = (@as(u16, @intCast(i)) << 8) | @as(u16, @intCast(j));
641 m.rows[i].ps[j] = polyDeriveUniform(rho, nonce);
642 }
643 }
644 return m;
645 }
646
647 /// Multiply matrix by vector in NTT domain and return result in regular domain.
648 /// Takes a vector in NTT form and returns the product in regular form.
649 fn mulVec(self: Self, v_hat: VecL) VecK {
650 var result = VecK.zero;
651 for (0..k) |i| {
652 result.ps[i] = dotHat(l, self.rows[i], v_hat);
653 result.ps[i] = result.ps[i].reduceLe2Q();
654 result.ps[i] = result.ps[i].invNTT();
655 }
656 return result;
657 }
658
659 /// Multiply matrix by vector in NTT domain and return result in NTT domain.
660 /// Takes a vector in NTT form and returns the product in NTT form.
661 fn mulVecHat(self: Self, v_hat: VecL) VecK {
662 var result: VecK = undefined;
663 for (0..k) |i| {
664 result.ps[i] = dotHat(l, self.rows[i], v_hat);
665 }
666 return result;
667 }
668 };
669}
670
671// Dot product in NTT domain
672fn dotHat(comptime len: u8, a: PolyVec(len), b: PolyVec(len)) Poly {
673 var ret = Poly.zero;
674 for (0..len) |i| {
675 const prod = a.ps[i].mulHat(b.ps[i]);
676 ret = ret.add(prod);
677 }
678 return ret;
679}
680
681// Modular arithmetic operations
682
683// Reduce x to [0, 2q) using the fact that 2^23 = 2^13 - 1 (mod q)
684fn le2Q(x: u32) u32 {
685 // Write x = x1 * 2^23 + x2 with x2 < 2^23 and x1 < 2^9
686 // Then x = x2 + x1 * 2^13 - x1 (mod q)
// and x2 + x1 * 2^13 - x1 < 2^23 + 2^22 < 2q
688 const x1 = x >> 23;
689 const x2 = x & 0x7FFFFF; // 2^23 - 1
690 return x2 +% (x1 << 13) -% x1;
691}
692
693// Reduce x to [0, q)
694fn modQ(x: u32) u32 {
695 return le2qModQ(le2Q(x));
696}
697
698// Given x < 2q, reduce to [0, q)
699fn le2qModQ(x: u32) u32 {
700 const r = x -% Q;
701 const mask = signMask(u32, r);
702 return r +% (mask & Q);
703}
704
705// Montgomery reduction: for x < q*2^32, return y < 2q where y ≡ x*R^(-1) (mod q)
706// where R = 2^32. This is used for efficient modular multiplication in NTT operations.
707fn montReduceLe2Q(x: u64) u32 {
708 const m = (x *% Q_INV) & 0xffffffff;
709 return @truncate((x +% m * @as(u64, Q)) >> 32);
710}
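
// Illustrative property test: the branchless reductions above agree with plain
// integer arithmetic, and Montgomery reduction divides by 2^32 modulo Q.
test "le2Q, modQ and montReduceLe2Q match reference arithmetic" {
    var x: u32 = 0;
    while (x < 1 << 30) : (x += 12345677) {
        try testing.expect(le2Q(x) < 2 * Q);
        try testing.expectEqual(x % Q, le2Q(x) % Q);
        try testing.expectEqual(x % Q, modQ(x));
    }

    var a: u64 = 1;
    while (a < (@as(u64, Q) << 32)) : (a = a * 7 + 123456789) {
        const y = montReduceLe2Q(a);
        try testing.expect(y < 2 * Q);
        // y * 2^32 ≡ a (mod Q)
        try testing.expectEqual(a % Q, (@as(u64, y) << 32) % Q);
    }
}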
711
712// Precomputed zetas for NTT (Montgomery form)
713// zetas[i] = zeta^brv(i) * R mod q
714const zetas = computeZetas();
715
716fn computeZetas() [N]u32 {
717 @setEvalBranchQuota(100000);
718 var ret: [N]u32 = undefined;
719
720 for (0..N) |i| {
721 const brv_i = @bitReverse(@as(u8, @intCast(i)));
722 const power = modularPow(u32, ZETA, brv_i, Q);
723 ret[i] = toMont(power);
724 }
725
726 return ret;
727}
728
729// Precomputed inverse zetas for inverse NTT
730const inv_zetas = computeInvZetas();
731
732fn computeInvZetas() [N]u32 {
733 @setEvalBranchQuota(100000);
734 var ret: [N]u32 = undefined;
735
736 const inv_zeta = modularInverse(u32, ZETA, Q);
737
738 for (0..N) |i| {
739 const idx = 255 - i;
740 const brv_idx = @bitReverse(@as(u8, @intCast(idx)));
741
742 // Exponent is -(brv_idx - 256) = 256 - brv_idx
743 const exp: u32 = @as(u32, 256) - brv_idx;
744
745 // Compute inv_zeta^exp
746 const power = modularPow(u32, inv_zeta, exp, Q);
747
748 // Convert to Montgomery form
749 ret[i] = toMont(power);
750 }
751
752 return ret;
753}
754
// Convert to Montgomery form: x -> x * R mod q, with R = 2^32.
// Computed as montReduceLe2Q(x * R^2 mod q), where R^2 mod q is precomputed at comptime.
fn toMont(x: u32) u32 {
    const r_mod_q = comptime blk: {
        // 2^32 mod q, by repeated doubling
        var r: u64 = 1;
        for (0..32) |_| {
            r = (r * 2) % Q;
        }
        break :blk @as(u32, @intCast(r));
    };

    const r2_mod_q = comptime blk: {
        const r = @as(u64, r_mod_q);
        break :blk @as(u32, @intCast((r * r) % Q));
    };

    return montReduceLe2Q(@as(u64, x) * @as(u64, r2_mod_q));
}
787
788/// Splits 0 ≤ a < Q into a0 and a1 with a = a1*2^D + a0
789/// and -2^(D-1) < a0 ≤ 2^(D-1). Returns a0 + Q and a1.
790/// FIPS 204: Power2Round (Algorithm 19)
791fn power2Round(a: u32) struct { a0_plus_q: u32, a1: u32 } {
792 // We effectively compute a0 = a mod± 2^D
793 // and a1 = (a - a0) / 2^D
794 var a0 = a & ((1 << D) - 1); // a mod 2^D
795
796 // a0 is one of 0, 1, ..., 2^(D-1)-1, 2^(D-1), 2^(D-1)+1, ..., 2^D-1
797 a0 -%= (1 << (D - 1)) + 1;
798 // now a0 is -2^(D-1)-1, -2^(D-1), ..., -2, -1, 0, ..., 2^(D-1)-2
799
800 // Next, add 2^D to those a0 that are negative (seen as i32)
801 a0 +%= @as(u32, @bitCast(@as(i32, @bitCast(a0)) >> 31)) & (1 << D);
802 // now a0 is 2^(D-1)-1, 2^(D-1), ..., 2^D-2, 2^D-1, 0, ..., 2^(D-1)-2
803
804 a0 -%= (1 << (D - 1)) - 1;
805 // now a0 is 0, 1, 2, ..., 2^(D-1)-1, 2^(D-1), -2^(D-1)+1, ..., -1
806
807 const a0_plus_q = Q +% a0;
808 const a1 = (a -% a0) >> D;
809
810 return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
811}
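
// Illustrative property test: power2Round really is the decomposition
// a = a1*2^D + a0 with -2^(D-1) < a0 <= 2^(D-1), a0 being returned as a0 + Q.
test "power2Round splits its argument as a1*2^D + a0" {
    var a: u32 = 0;
    while (a < Q) : (a += 100003) {
        const r = power2Round(a);
        const a0 = @as(i64, r.a0_plus_q) - Q;
        try testing.expect(a0 > -(1 << (D - 1)) and a0 <= (1 << (D - 1)));
        try testing.expectEqual(@as(i64, a), @as(i64, r.a1) * (1 << D) + a0);
    }
}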
812
813/// Splits 0 ≤ a < q into a0 and a1 with a = a1*alpha + a0 with -alpha/2 < a0 ≤ alpha/2,
814/// except when we would have a1 = (q-1)/alpha in which case a1=0 is taken
815/// and -alpha/2 ≤ a0 < 0. Returns a0 + q. Note 0 ≤ a1 < (q-1)/alpha.
816/// Recall alpha = 2*gamma2.
817fn decompose(a: u32, comptime gamma2: u32) struct { a0_plus_q: u32, a1: u32 } {
818 const alpha = 2 * gamma2;
819
820 // a1 = ⌈a / 128⌉
821 var a1 = (a + 127) >> 7;
822
823 if (alpha == 523776) {
        // For ML-DSA-65 and ML-DSA-87: gamma2 = 261888, alpha = 523776
        // 1025/2^22 is close enough to 1/4092 for a1 to become a/alpha, suitably rounded
826 a1 = ((a1 * 1025 + (1 << 21)) >> 22);
827
828 // For the corner-case a1 = (q-1)/alpha = 16, we have to set a1=0
829 a1 &= 15;
830 } else if (alpha == 190464) {
        // For ML-DSA-44: gamma2 = 95232, alpha = 190464
        // 11275/2^24 is close enough to 1/1488 for a1 to become a/alpha, suitably rounded
833 a1 = ((a1 * 11275) + (1 << 23)) >> 24;
834
835 // For the corner-case a1 = (q-1)/alpha = 44, we have to set a1=0
836 a1 ^= @as(u32, @bitCast(@as(i32, @bitCast(43 -% a1)) >> 31)) & a1;
837 } else {
838 @compileError("unsupported gamma2/alpha value");
839 }
840
841 var a0_plus_q = a -% a1 * alpha;
842
843 // In the corner-case, when we set a1=0, we will incorrectly
844 // have a0 > (q-1)/2 and we'll need to subtract q. As we
845 // return a0 + q, that comes down to adding q if a0 < (q-1)/2.
846 a0_plus_q +%= @as(u32, @bitCast(@as(i32, @bitCast(a0_plus_q -% (Q - 1) / 2)) >> 31)) & Q;
847
848 return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
849}
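
// Illustrative property test: decompose is consistent with a = a1*alpha + a0 (mod Q),
// where alpha = 2*gamma2 and a0 is returned as a0 + Q.
test "decompose splits its argument as a1*alpha + a0 mod Q" {
    inline for (.{ 95232, 261888 }) |gamma2| {
        const q: i64 = Q;
        const alpha: i64 = 2 * gamma2;
        var a: u32 = 0;
        while (a < Q) : (a += 99991) {
            const r = decompose(a, gamma2);
            const a0 = @as(i64, r.a0_plus_q) - q;
            try testing.expect(a0 >= -gamma2 and a0 <= gamma2);
            try testing.expectEqual(@as(i64, a), @mod(@as(i64, r.a1) * alpha + a0, q));
        }
    }
}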
850
851/// Creates a hint bit to help recover high bits after a small perturbation.
852/// Given:
853/// - z0: the modified low bits (r0 - f mod Q) where f is small
854/// - r1: the original high bits
855/// Returns 1 if a hint is needed, 0 otherwise.
856///
857/// This implements makeHint from FIPS 204. The hint helps recover r1 from
858/// r' = r - f without knowing f explicitly.
859fn makeHint(z0: u32, r1: u32, comptime gamma2: u32) u32 {
860 // If -alpha/2 < r0 - f <= alpha/2, then r1*alpha + r0 - f is a valid
861 // decomposition of r' with the restrictions of decompose() and so r'1 = r1.
862 // So the hint should be 0. This is covered by the first two inequalities.
863 // There is one other case: if r0 - f = -alpha/2, then r1*alpha + r0 - f is
864 // also a valid decomposition if r1 = 0. In the other cases a one is carried
865 // and the hint should be 1.
866
867 const cond1 = @intFromBool(z0 <= gamma2);
868 const cond2 = @intFromBool(z0 > Q - gamma2);
869 const eq_gamma2 = @intFromBool(z0 == Q - gamma2);
870 const r1_is_zero = @intFromBool(r1 == 0);
871 const cond3 = eq_gamma2 & r1_is_zero;
872
873 return 1 - (cond1 | cond2 | cond3);
874}
875
876/// Uses a hint to reconstruct high bits from a perturbed value.
877/// Given:
878/// - rp: the perturbed value (r' = r - f)
879/// - hint: the hint bit from makeHint
880/// Returns the reconstructed high bits r1.
881///
882/// This implements useHint from FIPS 204.
883fn useHint(rp: u32, hint: u32, comptime gamma2: u32) u32 {
884 const decomp = decompose(rp, gamma2);
885 const rp0_plus_q = decomp.a0_plus_q;
886 var rp1 = decomp.a1;
887
888 if (hint == 0) {
889 return rp1;
890 }
891
892 // Depending on gamma2, handle the adjustment differently
893 if (gamma2 == 261888) {
894 // ML-DSA-65 and ML-DSA-87: max r1 is 15
895 if (rp0_plus_q > Q) {
896 rp1 = (rp1 + 1) & 15;
897 } else {
898 rp1 = (rp1 -% 1) & 15;
899 }
900 } else if (gamma2 == 95232) {
901 // ML-DSA-44: max r1 is 43
902 if (rp0_plus_q > Q) {
903 if (rp1 == 43) {
904 rp1 = 0;
905 } else {
906 rp1 += 1;
907 }
908 } else {
909 if (rp1 == 0) {
910 rp1 = 43;
911 } else {
912 rp1 -= 1;
913 }
914 }
915 } else {
916 @compileError("unsupported gamma2 value");
917 }
918
919 return rp1;
920}
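
// Illustrative property test of the hint lemma that signing and verification
// rely on: makeHint and useHint together recover the high bits of r from the
// perturbed value r - f, for any |f| <= gamma2.
test "makeHint/useHint recover the high bits of a perturbed value" {
    inline for (.{ 95232, 261888 }) |gamma2| {
        const q: i64 = Q;
        var r: u32 = 0;
        while (r < Q) : (r += 65537) {
            const d = decompose(r, gamma2);
            var f: i64 = -gamma2;
            while (f <= gamma2) : (f += 31991) {
                // Representatives of r0 - f and r - f, normalized to [0, Q).
                const z0: u32 = @intCast(@mod(@as(i64, d.a0_plus_q) - q - f, q));
                const rp: u32 = @intCast(@mod(@as(i64, r) - f, q));
                const hint = makeHint(z0, d.a1, gamma2);
                try testing.expectEqual(d.a1, useHint(rp, hint, gamma2));
            }
        }
    }
}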
921
922/// Creates a hint polynomial for the difference between perturbed and original high bits.
923/// Returns the number of hint bits set to 1 (the population count).
924///
925/// This is used during signature generation to create hints that help verification
926/// recover the high bits without access to the secret.
927fn polyMakeHint(p0: Poly, p1: Poly, comptime gamma2: u32) struct { hint: Poly, count: u32 } {
928 var hint = Poly.zero;
929 var count: u32 = 0;
930
931 for (0..N) |i| {
932 const h = makeHint(p0.cs[i], p1.cs[i], gamma2);
933 hint.cs[i] = h;
934 count += h;
935 }
936
937 return .{ .hint = hint, .count = count };
938}
939
940/// Applies hints to reconstruct high bits from a perturbed polynomial.
941///
942/// This is used during signature verification to recover the high bits
943/// using the hints provided in the signature.
944fn polyUseHint(q: Poly, hint: Poly, comptime gamma2: u32) Poly {
945 var result = Poly.zero;
946
947 for (0..N) |i| {
948 result.cs[i] = useHint(q.cs[i], hint.cs[i], gamma2);
949 }
950
951 return result;
952}
953
954/// Pack polynomial with coefficients in [Q-eta, Q+eta] into bytes.
955/// For eta=2: packs coefficients into 3 bits each (96 bytes total)
956/// For eta=4: packs coefficients into 4 bits each (128 bytes total)
957/// Assumes coefficients are not normalized, but in [q-η, q+η].
958fn polyPackLeqEta(p: Poly, comptime eta: u8, buf: []u8) void {
959 comptime {
960 if (eta != 2 and eta != 4) {
961 @compileError("eta must be 2 or 4");
962 }
963 }
964
965 if (eta == 2) {
966 // 3 bits per coefficient: pack 8 coefficients into 3 bytes
967 var j: usize = 0;
968 var i: usize = 0;
969 while (i < buf.len) : (i += 3) {
970 const c0 = Q + eta - p.cs[j];
971 const c1 = Q + eta - p.cs[j + 1];
972 const c2 = Q + eta - p.cs[j + 2];
973 const c3 = Q + eta - p.cs[j + 3];
974 const c4 = Q + eta - p.cs[j + 4];
975 const c5 = Q + eta - p.cs[j + 5];
976 const c6 = Q + eta - p.cs[j + 6];
977 const c7 = Q + eta - p.cs[j + 7];
978
979 buf[i] = @truncate(c0 | (c1 << 3) | (c2 << 6));
980 buf[i + 1] = @truncate((c2 >> 2) | (c3 << 1) | (c4 << 4) | (c5 << 7));
981 buf[i + 2] = @truncate((c5 >> 1) | (c6 << 2) | (c7 << 5));
982
983 j += 8;
984 }
985 } else { // eta == 4
986 // 4 bits per coefficient: pack 2 coefficients into 1 byte
987 var j: usize = 0;
988 for (0..buf.len) |i| {
989 const c0 = Q + eta - p.cs[j];
990 const c1 = Q + eta - p.cs[j + 1];
991 buf[i] = @truncate(c0 | (c1 << 4));
992 j += 2;
993 }
994 }
995}
996
997/// Unpack polynomial with coefficients in [Q-eta, Q+eta] from bytes.
998/// Output coefficients will not be normalized, but in [q-η, q+η].
999fn polyUnpackLeqEta(comptime eta: u8, buf: []const u8) Poly {
1000 comptime {
1001 if (eta != 2 and eta != 4) {
1002 @compileError("eta must be 2 or 4");
1003 }
1004 }
1005
1006 var p = Poly.zero;
1007
1008 if (eta == 2) {
1009 // 3 bits per coefficient: unpack 8 coefficients from 3 bytes
1010 var j: usize = 0;
1011 var i: usize = 0;
1012 while (i < buf.len) : (i += 3) {
1013 p.cs[j] = Q + eta - (buf[i] & 7);
1014 p.cs[j + 1] = Q + eta - ((buf[i] >> 3) & 7);
1015 p.cs[j + 2] = Q + eta - ((buf[i] >> 6) | ((buf[i + 1] << 2) & 7));
1016 p.cs[j + 3] = Q + eta - ((buf[i + 1] >> 1) & 7);
1017 p.cs[j + 4] = Q + eta - ((buf[i + 1] >> 4) & 7);
1018 p.cs[j + 5] = Q + eta - ((buf[i + 1] >> 7) | ((buf[i + 2] << 1) & 7));
1019 p.cs[j + 6] = Q + eta - ((buf[i + 2] >> 2) & 7);
1020 p.cs[j + 7] = Q + eta - ((buf[i + 2] >> 5) & 7);
1021 j += 8;
1022 }
1023 } else { // eta == 4
1024 // 4 bits per coefficient: unpack 2 coefficients from 1 byte
1025 var j: usize = 0;
1026 for (0..buf.len) |i| {
1027 p.cs[j] = Q + eta - (buf[i] & 15);
1028 p.cs[j + 1] = Q + eta - (buf[i] >> 4);
1029 j += 2;
1030 }
1031 }
1032
1033 return p;
1034}
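
// Illustrative round-trip test: packing and unpacking eta-bounded coefficients
// is lossless (expandS, defined further below, provides in-range input).
test "polyPackLeqEta/polyUnpackLeqEta round-trip" {
    const seed = [_]u8{3} ** 64;
    inline for (.{ 2, 4 }) |eta| {
        const p = expandS(eta, &seed, 1);
        const size = if (eta == 2) 96 else 128;
        var buf: [size]u8 = undefined;
        polyPackLeqEta(p, eta, &buf);
        try testing.expectEqual(p, polyUnpackLeqEta(eta, &buf));
    }
}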
1035
1036/// Pack polynomial with coefficients < 1024 (T1) into bytes.
1037/// Packs 10 bits per coefficient: 4 coefficients into 5 bytes.
1038/// Assumes coefficients are normalized.
1039fn polyPackT1(p: Poly, buf: []u8) void {
1040 var j: usize = 0;
1041 var i: usize = 0;
1042 while (i < buf.len) : (i += 5) {
1043 buf[i] = @truncate(p.cs[j]);
1044 buf[i + 1] = @truncate((p.cs[j] >> 8) | (p.cs[j + 1] << 2));
1045 buf[i + 2] = @truncate((p.cs[j + 1] >> 6) | (p.cs[j + 2] << 4));
1046 buf[i + 3] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 6));
1047 buf[i + 4] = @truncate(p.cs[j + 3] >> 2);
1048 j += 4;
1049 }
1050}
1051
1052/// Unpack polynomial with coefficients < 1024 (T1) from bytes.
1053/// Output coefficients will be normalized.
1054fn polyUnpackT1(buf: []const u8) Poly {
1055 var p = Poly.zero;
1056 var j: usize = 0;
1057 var i: usize = 0;
1058 while (i < buf.len) : (i += 5) {
1059 p.cs[j] = (@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x3ff;
1060 p.cs[j + 1] = ((@as(u32, buf[i + 1]) >> 2) | (@as(u32, buf[i + 2]) << 6)) & 0x3ff;
1061 p.cs[j + 2] = ((@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4)) & 0x3ff;
1062 p.cs[j + 3] = ((@as(u32, buf[i + 3]) >> 6) | (@as(u32, buf[i + 4]) << 2)) & 0x3ff;
1063 j += 4;
1064 }
1065 return p;
1066}
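
// Illustrative round-trip test: 10-bit T1 packing is lossless for coefficients
// below 2^10.
test "polyPackT1/polyUnpackT1 round-trip" {
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| c.* = @intCast((i * 37 + 11) % 1024);
    var buf: [(N * (Q_BITS - D)) / 8]u8 = undefined;
    polyPackT1(p, &buf);
    try testing.expectEqual(p, polyUnpackT1(&buf));
}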
1067
1068/// Pack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) into bytes.
1069/// Packs 13 bits per coefficient: 8 coefficients into 13 bytes.
1070/// Assumes coefficients are not normalized, but in (q-2^(D-1), q+2^(D-1)].
1071fn polyPackT0(p: Poly, buf: []u8) void {
1072 const bound = 1 << (D - 1);
1073 var j: usize = 0;
1074 var i: usize = 0;
1075 while (i < buf.len) : (i += 13) {
1076 const p0 = Q + bound - p.cs[j];
1077 const p1 = Q + bound - p.cs[j + 1];
1078 const p2 = Q + bound - p.cs[j + 2];
1079 const p3 = Q + bound - p.cs[j + 3];
1080 const p4 = Q + bound - p.cs[j + 4];
1081 const p5 = Q + bound - p.cs[j + 5];
1082 const p6 = Q + bound - p.cs[j + 6];
1083 const p7 = Q + bound - p.cs[j + 7];
1084
1085 buf[i] = @truncate(p0 >> 0);
1086 buf[i + 1] = @truncate((p0 >> 8) | (p1 << 5));
1087 buf[i + 2] = @truncate(p1 >> 3);
1088 buf[i + 3] = @truncate((p1 >> 11) | (p2 << 2));
1089 buf[i + 4] = @truncate((p2 >> 6) | (p3 << 7));
1090 buf[i + 5] = @truncate(p3 >> 1);
1091 buf[i + 6] = @truncate((p3 >> 9) | (p4 << 4));
1092 buf[i + 7] = @truncate(p4 >> 4);
1093 buf[i + 8] = @truncate((p4 >> 12) | (p5 << 1));
1094 buf[i + 9] = @truncate((p5 >> 7) | (p6 << 6));
1095 buf[i + 10] = @truncate(p6 >> 2);
1096 buf[i + 11] = @truncate((p6 >> 10) | (p7 << 3));
1097 buf[i + 12] = @truncate(p7 >> 5);
1098
1099 j += 8;
1100 }
1101}
1102
1103/// Unpack polynomial with coefficients in (-2^(D-1), 2^(D-1)] (T0) from bytes.
1104/// Output coefficients will not be normalized, but in (-2^(D-1), 2^(D-1)].
1105fn polyUnpackT0(buf: []const u8) Poly {
1106 const bound = 1 << (D - 1);
1107 var p = Poly.zero;
1108 var j: usize = 0;
1109 var i: usize = 0;
1110 while (i < buf.len) : (i += 13) {
1111 p.cs[j] = Q + bound - ((@as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8)) & 0x1fff);
1112 p.cs[j + 1] = Q + bound - (((@as(u32, buf[i + 1]) >> 5) | (@as(u32, buf[i + 2]) << 3) | (@as(u32, buf[i + 3]) << 11)) & 0x1fff);
1113 p.cs[j + 2] = Q + bound - (((@as(u32, buf[i + 3]) >> 2) | (@as(u32, buf[i + 4]) << 6)) & 0x1fff);
1114 p.cs[j + 3] = Q + bound - (((@as(u32, buf[i + 4]) >> 7) | (@as(u32, buf[i + 5]) << 1) | (@as(u32, buf[i + 6]) << 9)) & 0x1fff);
1115 p.cs[j + 4] = Q + bound - (((@as(u32, buf[i + 6]) >> 4) | (@as(u32, buf[i + 7]) << 4) | (@as(u32, buf[i + 8]) << 12)) & 0x1fff);
1116 p.cs[j + 5] = Q + bound - (((@as(u32, buf[i + 8]) >> 1) | (@as(u32, buf[i + 9]) << 7)) & 0x1fff);
1117 p.cs[j + 6] = Q + bound - (((@as(u32, buf[i + 9]) >> 6) | (@as(u32, buf[i + 10]) << 2) | (@as(u32, buf[i + 11]) << 10)) & 0x1fff);
1118 p.cs[j + 7] = Q + bound - ((@as(u32, buf[i + 11]) >> 3) | (@as(u32, buf[i + 12]) << 5));
1119 j += 8;
1120 }
1121 return p;
1122}
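
// Illustrative round-trip test: 13-bit T0 packing is lossless for the
// power2Round low parts it is applied to.
test "polyPackT0/polyUnpackT0 round-trip" {
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| c.* = power2Round(@as(u32, @intCast((i * 32603) % Q))).a0_plus_q;
    var buf: [(N * D) / 8]u8 = undefined;
    polyPackT0(p, &buf);
    try testing.expectEqual(p, polyUnpackT0(&buf));
}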
1123
1124/// Convert coefficient from centered representation to non-negative.
1125/// Transforms value from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁).
1126fn centeredToPositive(val: u32, comptime gamma1: u32) u32 {
1127 var result = gamma1 -% val;
1128 result +%= (signMask(u32, result) & Q);
1129 return result;
1130}
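
// Illustrative property test: the map above is an involution, which is why it is
// also used for the reverse conversion when unpacking.
test "centeredToPositive is an involution" {
    const gamma1: u32 = 1 << 17;
    var v: u32 = 0;
    while (v < 2 * gamma1) : (v += 997) {
        try testing.expectEqual(v, centeredToPositive(centeredToPositive(v, gamma1), gamma1));
    }
}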
1131
1132/// Pack polynomial with coefficients in (-gamma1, gamma1] into bytes.
1133/// For gamma1_bits=17: packs 18 bits per coefficient (4 coefficients into 9 bytes)
1134/// For gamma1_bits=19: packs 20 bits per coefficient (2 coefficients into 5 bytes)
1135/// Assumes coefficients are normalized.
1136fn polyPackLeGamma1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void {
1137 const gamma1: u32 = @as(u32, 1) << gamma1_bits;
1138
1139 if (gamma1_bits == 17) {
1140 // Pack 4 coefficients into 9 bytes (18 bits each)
1141 var j: usize = 0;
1142 var i: usize = 0;
1143 while (i < buf.len) : (i += 9) {
1144 // Convert from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁)
1145 const p0 = centeredToPositive(p.cs[j], gamma1);
1146 const p1 = centeredToPositive(p.cs[j + 1], gamma1);
1147 const p2 = centeredToPositive(p.cs[j + 2], gamma1);
1148 const p3 = centeredToPositive(p.cs[j + 3], gamma1);
1149
1150 buf[i] = @truncate(p0);
1151 buf[i + 1] = @truncate(p0 >> 8);
1152 buf[i + 2] = @truncate((p0 >> 16) | (p1 << 2));
1153 buf[i + 3] = @truncate(p1 >> 6);
1154 buf[i + 4] = @truncate((p1 >> 14) | (p2 << 4));
1155 buf[i + 5] = @truncate(p2 >> 4);
1156 buf[i + 6] = @truncate((p2 >> 12) | (p3 << 6));
1157 buf[i + 7] = @truncate(p3 >> 2);
1158 buf[i + 8] = @truncate(p3 >> 10);
1159
1160 j += 4;
1161 }
1162 } else if (gamma1_bits == 19) {
1163 // Pack 2 coefficients into 5 bytes (20 bits each)
1164 var j: usize = 0;
1165 var i: usize = 0;
1166 while (i < buf.len) : (i += 5) {
1167 const p0 = centeredToPositive(p.cs[j], gamma1);
1168 const p1 = centeredToPositive(p.cs[j + 1], gamma1);
1169
1170 buf[i] = @truncate(p0);
1171 buf[i + 1] = @truncate(p0 >> 8);
1172 buf[i + 2] = @truncate((p0 >> 16) | (p1 << 4));
1173 buf[i + 3] = @truncate(p1 >> 4);
1174 buf[i + 4] = @truncate(p1 >> 12);
1175
1176 j += 2;
1177 }
1178 } else {
1179 @compileError("gamma1_bits must be 17 or 19");
1180 }
1181}
1182
1183/// Unpack polynomial with coefficients in (-gamma1, gamma1] from bytes.
1184/// Output coefficients will be normalized.
1185fn polyUnpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Poly {
1186 const gamma1: u32 = @as(u32, 1) << gamma1_bits;
1187 var p = Poly.zero;
1188
1189 if (gamma1_bits == 17) {
1190 // Unpack 4 coefficients from 9 bytes (18 bits each)
1191 var j: usize = 0;
1192 var i: usize = 0;
1193 while (i < buf.len) : (i += 9) {
1194 var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0x3) << 16);
1195 var p1 = (@as(u32, buf[i + 2]) >> 2) | (@as(u32, buf[i + 3]) << 6) | ((@as(u32, buf[i + 4]) & 0xf) << 14);
1196 var p2 = (@as(u32, buf[i + 4]) >> 4) | (@as(u32, buf[i + 5]) << 4) | ((@as(u32, buf[i + 6]) & 0x3f) << 12);
1197 var p3 = (@as(u32, buf[i + 6]) >> 6) | (@as(u32, buf[i + 7]) << 2) | (@as(u32, buf[i + 8]) << 10);
1198
1199 // Convert from [0, 2γ₁) to (-γ₁, γ₁]
1200 p0 = centeredToPositive(p0, gamma1);
1201 p1 = centeredToPositive(p1, gamma1);
1202 p2 = centeredToPositive(p2, gamma1);
1203 p3 = centeredToPositive(p3, gamma1);
1204
1205 p.cs[j] = p0;
1206 p.cs[j + 1] = p1;
1207 p.cs[j + 2] = p2;
1208 p.cs[j + 3] = p3;
1209
1210 j += 4;
1211 }
1212 } else if (gamma1_bits == 19) {
1213 // Unpack 2 coefficients from 5 bytes (20 bits each)
1214 var j: usize = 0;
1215 var i: usize = 0;
1216 while (i < buf.len) : (i += 5) {
1217 var p0 = @as(u32, buf[i]) | (@as(u32, buf[i + 1]) << 8) | ((@as(u32, buf[i + 2]) & 0xf) << 16);
1218 var p1 = (@as(u32, buf[i + 2]) >> 4) | (@as(u32, buf[i + 3]) << 4) | (@as(u32, buf[i + 4]) << 12);
1219
1220 p0 = centeredToPositive(p0, gamma1);
1221 p1 = centeredToPositive(p1, gamma1);
1222
1223 p.cs[j] = p0;
1224 p.cs[j + 1] = p1;
1225
1226 j += 2;
1227 }
1228 } else {
1229 @compileError("gamma1_bits must be 17 or 19");
1230 }
1231
1232 return p;
1233}
1234
1235/// Pack W1 polynomial for verification.
1236/// For gamma1_bits=17: packs 6 bits per coefficient (4 coefficients into 3 bytes)
1237/// For gamma1_bits=19: packs 4 bits per coefficient (2 coefficients into 1 byte)
1238/// Assumes coefficients are normalized.
1239fn polyPackW1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void {
1240 if (gamma1_bits == 17) {
1241 // Pack 4 coefficients into 3 bytes (6 bits each)
1242 var j: usize = 0;
1243 var i: usize = 0;
1244 while (i < buf.len) : (i += 3) {
1245 buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 6));
1246 buf[i + 1] = @truncate((p.cs[j + 1] >> 2) | (p.cs[j + 2] << 4));
1247 buf[i + 2] = @truncate((p.cs[j + 2] >> 4) | (p.cs[j + 3] << 2));
1248 j += 4;
1249 }
1250 } else if (gamma1_bits == 19) {
1251 // Pack 2 coefficients into 1 byte (4 bits each) - equivalent to packLe16
1252 var j: usize = 0;
1253 for (0..buf.len) |i| {
1254 buf[i] = @truncate(p.cs[j] | (p.cs[j + 1] << 4));
1255 j += 2;
1256 }
1257 } else {
1258 @compileError("gamma1_bits must be 17 or 19");
1259 }
1260}
1261
1262fn polyDeriveUniform(seed: *const [32]u8, nonce: u16) Poly {
1263 var domain_sep: [2]u8 = undefined;
1264 domain_sep[0] = @truncate(nonce);
1265 domain_sep[1] = @truncate(nonce >> 8);
1266
1267 return sampleUniformRejection(
1268 Poly,
1269 Q,
1270 23,
1271 N,
1272 seed,
1273 &domain_sep,
1274 );
1275}
1276
1277/// Sample p uniformly with coefficients of norm less than or equal to η,
1278/// using the given seed and nonce with SHAKE-256.
1279/// The polynomial will not be normalized, but will have coefficients in [q-η, q+η].
1280/// FIPS 204: ExpandS (Algorithm 27)
1281fn expandS(comptime eta: u8, seed: *const [64]u8, nonce: u16) Poly {
1282 comptime {
1283 if (eta != 2 and eta != 4) {
1284 @compileError("eta must be 2 or 4");
1285 }
1286 }
1287
1288 var p = Poly.zero;
1289 var i: usize = 0;
1290
1291 var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes
1292
1293 // Prepare input: seed || nonce (little-endian u16)
1294 var input: [66]u8 = undefined;
1295 @memcpy(input[0..64], seed);
1296 input[64] = @truncate(nonce);
1297 input[65] = @truncate(nonce >> 8);
1298
1299 var h = sha3.Shake256.init(.{});
1300 h.update(&input);
1301
1302 while (i < N) {
1303 h.squeeze(&buf);
1304
1305 // Process buffer: extract two samples per byte (4-bit nibbles)
1306 var j: usize = 0;
1307 while (j < buf.len and i < N) : (j += 1) {
1308 var t1 = @as(u32, buf[j]) & 15;
1309 var t2 = @as(u32, buf[j]) >> 4;
1310
1311 if (eta == 2) {
1312 // For eta=2: reject if t > 14, then reduce mod 5
1313 if (t1 <= 14) {
1314 t1 -%= ((205 * t1) >> 10) * 5; // reduce mod 5
1315 p.cs[i] = Q + eta - t1;
1316 i += 1;
1317 }
1318 if (t2 <= 14 and i < N) {
1319 t2 -%= ((205 * t2) >> 10) * 5; // reduce mod 5
1320 p.cs[i] = Q + eta - t2;
1321 i += 1;
1322 }
1323 } else if (eta == 4) {
1324 // For eta=4: accept if t <= 2*eta = 8
1325 if (t1 <= 2 * eta) {
1326 p.cs[i] = Q + eta - t1;
1327 i += 1;
1328 }
1329 if (t2 <= 2 * eta and i < N) {
1330 p.cs[i] = Q + eta - t2;
1331 i += 1;
1332 }
1333 }
1334 }
1335 }
1336
1337 return p;
1338}
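
// Illustrative property test: expandS only emits coefficients in [Q-eta, Q+eta],
// as the eta packing code above assumes.
test "expandS coefficients lie in [Q-eta, Q+eta]" {
    const seed = [_]u8{9} ** 64;
    inline for (.{ 2, 4 }) |eta| {
        const p = expandS(eta, &seed, 0);
        for (p.cs) |c| {
            try testing.expect(c >= Q - eta and c <= Q + eta);
        }
    }
}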
1339
1340/// Sample p uniformly with τ non-zero coefficients in {Q-1, 1} using SHAKE-256.
1341/// This creates a "ball" polynomial with exactly tau non-zero ±1 coefficients.
1342/// The polynomial will be normalized with coefficients in {0, 1, Q-1}.
1343/// FIPS 204: SampleInBall (Algorithm 18)
1344fn sampleInBall(comptime tau: u16, seed: []const u8) Poly {
1345 var p = Poly.zero;
1346
1347 var buf: [sha3.Shake256.block_length]u8 = undefined; // SHAKE-256 rate is 136 bytes
1348
1349 var h = sha3.Shake256.init(.{});
1350 h.update(seed);
1351 h.squeeze(&buf);
1352
1353 // Extract signs from first 8 bytes
1354 var signs: u64 = 0;
1355 for (0..8) |j| {
1356 signs |= @as(u64, buf[j]) << @intCast(j * 8);
1357 }
1358 var buf_off: usize = 8;
1359
1360 // Generate tau non-zero coefficients using Fisher-Yates shuffle
1361 // Start with N-tau zeros, then add tau ±1 values
1362 var i: u16 = N - tau;
1363 while (i < N) : (i += 1) {
1364 var b: u16 = undefined;
1365
1366 // Find location using rejection sampling
1367 while (true) {
1368 if (buf_off >= buf.len) {
1369 h.squeeze(&buf);
1370 buf_off = 0;
1371 }
1372
1373 b = buf[buf_off];
1374 buf_off += 1;
1375
1376 if (b <= i) {
1377 break;
1378 }
1379 }
1380
1381 // Shuffle: move existing value to position i
1382 p.cs[i] = p.cs[b];
1383
1384 // Set position b to ±1 based on sign bit
1385 p.cs[b] = 1;
1386 const sign_bit: u1 = @truncate(signs);
1387 const mask = bitMask(u32, sign_bit);
1388 p.cs[b] ^= mask & (1 | (Q - 1));
1389 signs >>= 1;
1390 }
1391
1392 return p;
1393}
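
// Illustrative property test: the challenge polynomial has exactly tau non-zero
// coefficients, each equal to 1 or Q - 1.
test "sampleInBall produces exactly tau signs" {
    const seed = [_]u8{42} ** 32;
    const p = sampleInBall(39, &seed);
    var nonzero: usize = 0;
    for (p.cs) |c| {
        if (c != 0) {
            nonzero += 1;
            try testing.expect(c == 1 or c == Q - 1);
        }
    }
    try testing.expectEqual(@as(usize, 39), nonzero);
}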
1394
1395/// Sample a polynomial with coefficients uniformly distributed in (-gamma1, gamma1]
1396/// Used for sampling the masking vector y during signing
1397/// FIPS 204: ExpandMask (Algorithm 28)
1398fn expandMask(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Poly {
1399 const packed_size = ((gamma1_bits + 1) * N) / 8;
1400 var buf: [packed_size]u8 = undefined;
1401
1402 // Construct IV: seed || nonce (little-endian)
1403 var iv: [66]u8 = undefined;
1404 @memcpy(iv[0..64], seed);
1405 iv[64] = @truncate(nonce & 0xFF);
1406 iv[65] = @truncate(nonce >> 8);
1407
1408 var h = sha3.Shake256.init(.{});
1409 h.update(&iv);
1410 h.squeeze(&buf);
1411
1412 // Unpack the polynomial
1413 return polyUnpackLeGamma1(gamma1_bits, &buf);
1414}
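
// Illustrative round-trip test: ExpandMask output survives packing and unpacking
// unchanged, for both gamma1 settings used by the parameter sets.
test "polyPackLeGamma1/polyUnpackLeGamma1 round-trip" {
    const seed = [_]u8{7} ** 64;
    inline for (.{ 17, 19 }) |gamma1_bits| {
        const p = expandMask(gamma1_bits, &seed, 0);
        var buf: [((gamma1_bits + 1) * N) / 8]u8 = undefined;
        polyPackLeGamma1(p, gamma1_bits, &buf);
        try testing.expectEqual(p, polyUnpackLeGamma1(gamma1_bits, &buf));
    }
}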
1415
1416fn MLDSAImpl(comptime p: Params) type {
1417 return struct {
1418 pub const params = p;
1419 pub const name = p.name;
1420 pub const gamma1: u32 = @as(u32, 1) << p.gamma1_bits;
1421 pub const beta: u32 = p.tau * p.eta;
1422 pub const alpha: u32 = 2 * p.gamma2;
1423
1424 const Self = @This();
1425 const PolyVecL = PolyVec(p.l);
1426 const PolyVecK = PolyVec(p.k);
1427 const MatKxL = Mat(p.k, p.l);
1428
1429 /// Length of the seed used for deterministic key generation (32 bytes).
1430 pub const seed_length: usize = 32;
1431
1432 /// Length (in bytes) of optional random bytes, for non-deterministic signatures.
1433 pub const noise_length = 32;
1434
1435 /// Size of an encoded public key in bytes.
1436 pub const public_key_bytes: usize = 32 + polyT1PackedSize() * p.k;
1437
1438 /// Size of an encoded secret key in bytes.
1439 pub const private_key_bytes: usize = 32 + 32 + p.tr_size +
1440 polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k;
1441
1442 /// Size of an encoded signature in bytes.
1443 pub const signature_bytes: usize = p.ctilde_size +
1444 polyLeGamma1PackedSize() * p.l + p.omega + p.k;
1445
1446 // Packed sizes for different polynomial representations
1447 fn polyLeqEtaPackedSize() usize {
1448 // For eta=2: 3 bits per coefficient (values in [0,4])
1449 // For eta=4: 4 bits per coefficient (values in [0,8])
1450 const double_eta_bits = if (p.eta == 2) 3 else 4;
1451 return (N * double_eta_bits) / 8;
1452 }
1453
1454 fn polyLeGamma1PackedSize() usize {
1455 return ((p.gamma1_bits + 1) * N) / 8;
1456 }
1457
1458 fn polyT1PackedSize() usize {
1459 return (N * (Q_BITS - D)) / 8;
1460 }
1461
1462 fn polyT0PackedSize() usize {
1463 return (N * D) / 8;
1464 }
1465
1466 fn polyW1PackedSize() usize {
1467 return (N * (Q_BITS - p.gamma1_bits)) / 8;
1468 }
1469
1470 /// Helper function to compute CRH (Collision Resistant Hash) using SHAKE-256.
1471 /// This consolidates the repeated pattern of init-update-squeeze for hash operations.
1472 fn crh(comptime outsize: usize, inputs: anytype) [outsize]u8 {
1473 var h = sha3.Shake256.init(.{});
1474 inline for (inputs) |input| {
1475 h.update(input);
1476 }
1477 var out: [outsize]u8 = undefined;
1478 h.squeeze(&out);
1479 return out;
1480 }
1481
1482 /// Helper function to compute t = As1 + s2.
1483 /// This is used during key generation and public key reconstruction.
1484 fn computeT(A: MatKxL, s1_hat: PolyVecL, s2: PolyVecK) PolyVecK {
1485 const t = A.mulVec(s1_hat).add(s2);
1486 return t.normalize();
1487 }
1488
1489 /// ML-DSA public key
1490 pub const PublicKey = struct {
1491 /// Size of the encoded public key in bytes
1492 pub const encoded_length: usize = 32 + polyT1PackedSize() * p.k;
1493
1494 rho: [32]u8, // Seed for matrix A
1495 t1: PolyVecK, // High bits of t = As1 + s2
1496
1497 // Cached values
1498 t1_packed: [polyT1PackedSize() * p.k]u8,
1499 A: MatKxL,
1500 tr: [p.tr_size]u8, // CRH(rho || t1)
1501
1502 /// Encode public key to bytes
1503 pub fn toBytes(self: PublicKey) [encoded_length]u8 {
1504 var out: [encoded_length]u8 = undefined;
1505 @memcpy(out[0..32], &self.rho);
1506 @memcpy(out[32..], &self.t1_packed);
1507 return out;
1508 }
1509
1510 /// Decode public key from bytes
1511 pub fn fromBytes(bytes: [encoded_length]u8) !PublicKey {
1512 var pk: PublicKey = undefined;
1513 @memcpy(&pk.rho, bytes[0..32]);
1514 @memcpy(&pk.t1_packed, bytes[32..]);
1515
1516 pk.t1 = PolyVecK.unpackT1(pk.t1_packed[0..]);
1517 pk.A = MatKxL.derive(&pk.rho);
1518 pk.tr = crh(p.tr_size, .{&bytes});
1519
1520 return pk;
1521 }
1522 };
1523
1524 /// ML-DSA secret key
1525 pub const SecretKey = struct {
1526 /// Size of the encoded secret key in bytes
1527 pub const encoded_length: usize = 32 + 32 + p.tr_size +
1528 polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k;
1529
1530 rho: [32]u8, // Seed for matrix A
1531 key: [32]u8, // Seed for signature generation randomness
1532 tr: [p.tr_size]u8, // CRH(rho || t1)
1533 s1: PolyVecL, // Secret vector 1
1534 s2: PolyVecK, // Secret vector 2
1535 t0: PolyVecK, // Low bits of t = As1 + s2
1536
1537 // Cached values (in NTT domain)
1538 A: MatKxL,
1539 s1_hat: PolyVecL,
1540 s2_hat: PolyVecK,
1541 t0_hat: PolyVecK,
1542
1543 /// Encode secret key to bytes
1544 pub fn toBytes(self: SecretKey) [encoded_length]u8 {
1545 var out: [encoded_length]u8 = undefined;
1546 var offset: usize = 0;
1547
1548 @memcpy(out[offset .. offset + 32], &self.rho);
1549 offset += 32;
1550
1551 @memcpy(out[offset .. offset + 32], &self.key);
1552 offset += 32;
1553
1554 @memcpy(out[offset .. offset + p.tr_size], &self.tr);
1555 offset += p.tr_size;
1556
1557 if (p.eta == 2) {
1558 self.s1.packLeqEta(2, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1559 } else {
1560 self.s1.packLeqEta(4, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1561 }
1562 offset += p.l * polyLeqEtaPackedSize();
1563
1564 if (p.eta == 2) {
1565 self.s2.packLeqEta(2, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1566 } else {
1567 self.s2.packLeqEta(4, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1568 }
1569 offset += p.k * polyLeqEtaPackedSize();
1570
1571 self.t0.packT0(out[offset..][0 .. p.k * polyT0PackedSize()]);
1572 offset += p.k * polyT0PackedSize();
1573
1574 return out;
1575 }
1576
1577 /// Decode secret key from bytes
1578 pub fn fromBytes(bytes: [encoded_length]u8) !SecretKey {
1579 var sk: SecretKey = undefined;
1580 var offset: usize = 0;
1581
1582 @memcpy(&sk.rho, bytes[offset .. offset + 32]);
1583 offset += 32;
1584
1585 @memcpy(&sk.key, bytes[offset .. offset + 32]);
1586 offset += 32;
1587
1588 @memcpy(&sk.tr, bytes[offset .. offset + p.tr_size]);
1589 offset += p.tr_size;
1590
1591 sk.s1 = if (p.eta == 2)
1592 PolyVecL.unpackLeqEta(2, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()])
1593 else
1594 PolyVecL.unpackLeqEta(4, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1595 offset += p.l * polyLeqEtaPackedSize();
1596
1597 sk.s2 = if (p.eta == 2)
1598 PolyVecK.unpackLeqEta(2, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()])
1599 else
1600 PolyVecK.unpackLeqEta(4, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1601 offset += p.k * polyLeqEtaPackedSize();
1602
1603 sk.t0 = PolyVecK.unpackT0(bytes[offset..][0 .. p.k * polyT0PackedSize()]);
1604 offset += p.k * polyT0PackedSize();
1605
1606 // Compute cached NTT values for efficient signing
1607 sk.A = MatKxL.derive(&sk.rho);
1608 sk.s1_hat = sk.s1.ntt();
1609 sk.s2_hat = sk.s2.ntt();
1610 sk.t0_hat = sk.t0.ntt();
1611
1612 return sk;
1613 }
1614
1615 /// Compute the public key from this private key
1616 pub fn public(self: *const SecretKey) PublicKey {
1617 var pk: PublicKey = undefined;
1618 pk.rho = self.rho;
1619 pk.A = self.A;
1620 pk.tr = self.tr;
1621
1622 // Reconstruct t = As1 + s2, then extract high bits t1
1623 // Using power2Round: t = t1 * 2^D + t0
1624 const t = computeT(self.A, self.s1_hat, self.s2);
1625
1626 var t0_unused: PolyVecK = undefined;
1627 pk.t1 = t.power2Round(&t0_unused);
1628 pk.t1.packT1(&pk.t1_packed);
1629
1630 return pk;
1631 }
1632
1633 /// Create a Signer for incrementally signing a message.
1634 /// The noise parameter can be null for deterministic signatures,
1635 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1636 pub fn signer(self: *const SecretKey, noise: ?[noise_length]u8) !Signer {
1637 return self.signerWithContext(noise, "");
1638 }
1639
1640 /// Create a Signer for incrementally signing a message with context.
1641 /// The noise parameter can be null for deterministic signatures,
1642 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1643 /// The context parameter is an optional context string (max 255 bytes).
1644 pub fn signerWithContext(self: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
1645 return Signer.init(self, noise, context);
1646 }
1647 };
1648
1649 /// Generate a new key pair from a seed (deterministic)
1650 pub fn newKeyFromSeed(seed: *const [seed_length]u8) struct { pk: PublicKey, sk: SecretKey } {
1651 var sk: SecretKey = undefined;
1652 var pk: PublicKey = undefined;
1653
1654 // NIST mode: expand seed || k || l using SHAKE-256 to get 128-byte expanded seed
1655 const e_seed = crh(128, .{ seed, &[_]u8{ p.k, p.l } });
1656
1657 @memcpy(&pk.rho, e_seed[0..32]);
1658 const s_seed = e_seed[32..96];
1659 @memcpy(&sk.key, e_seed[96..128]);
1660 @memcpy(&sk.rho, &pk.rho);
1661
1662 sk.A = MatKxL.derive(&pk.rho);
1663 pk.A = sk.A;
1664
1665 const s_seed_array: *const [64]u8 = s_seed[0..64];
1666 for (0..p.l) |i| {
1667 sk.s1.ps[i] = expandS(p.eta, s_seed_array, @intCast(i));
1668 }
1669
1670 for (0..p.k) |i| {
1671 sk.s2.ps[i] = expandS(p.eta, s_seed_array, @intCast(p.l + i));
1672 }
1673
1674 sk.s1_hat = sk.s1.ntt();
1675 sk.s2_hat = sk.s2.ntt();
1676
1677 const t = computeT(sk.A, sk.s1_hat, sk.s2);
1678
1679 pk.t1 = t.power2Round(&sk.t0);
1680 sk.t0_hat = sk.t0.ntt();
1681 pk.t1.packT1(&pk.t1_packed);
1682
1683 // tr = H(pk) = H(rho || t1)
1684 const pk_bytes = pk.toBytes();
1685 const tr = crh(p.tr_size, .{&pk_bytes});
1686 sk.tr = tr;
1687 pk.tr = tr;
1688
1689 return .{ .pk = pk, .sk = sk };
1690 }
1691
1692 /// ML-DSA signature
1693 pub const Signature = struct {
1694 /// Size of the encoded signature in bytes
1695 pub const encoded_length: usize = p.ctilde_size +
1696 polyLeGamma1PackedSize() * p.l + p.omega + p.k;
1697
1698 c_tilde: [p.ctilde_size]u8, // Challenge hash
1699 z: PolyVecL, // Response vector
1700 hint: PolyVecK, // Hint vector
1701
1702 /// Encode signature to bytes
1703 pub fn toBytes(self: Signature) [encoded_length]u8 {
1704 var out: [encoded_length]u8 = undefined;
1705 var offset: usize = 0;
1706
1707 @memcpy(out[offset .. offset + p.ctilde_size], &self.c_tilde);
1708 offset += p.ctilde_size;
1709
1710 self.z.packLeGamma1(p.gamma1_bits, out[offset .. offset + polyLeGamma1PackedSize() * p.l]);
1711 offset += polyLeGamma1PackedSize() * p.l;
1712
1713 _ = self.hint.packHint(p.omega, out[offset..]);
1714
1715 return out;
1716 }
1717
1718 /// Decode signature from bytes
1719 pub fn fromBytes(bytes: [encoded_length]u8) EncodingError!Signature {
1720 var sig: Signature = undefined;
1721 var offset: usize = 0;
1722
1723 @memcpy(&sig.c_tilde, bytes[offset .. offset + p.ctilde_size]);
1724 offset += p.ctilde_size;
1725
1726 sig.z = PolyVecL.unpackLeGamma1(p.gamma1_bits, bytes[offset .. offset + polyLeGamma1PackedSize() * p.l]);
1727 offset += polyLeGamma1PackedSize() * p.l;
1728
1729 // Validate ||z||_inf < gamma1 - beta per FIPS 204
1730 if (sig.z.exceeds(gamma1 - beta)) {
1731 return error.InvalidEncoding;
1732 }
1733
1734 sig.hint = PolyVecK.unpackHint(p.omega, bytes[offset..]) orelse return error.InvalidEncoding;
1735
1736 return sig;
1737 }
1738
1739 pub const VerifyError = Verifier.InitError || Verifier.VerifyError;
1740
1741 /// Verify this signature against a message and public key.
1742 /// Returns an error if the signature is invalid.
1743 pub fn verify(
1744 sig: Signature,
1745 msg: []const u8,
1746 public_key: PublicKey,
1747 ) VerifyError!void {
1748 return sig.verifyWithContext(msg, public_key, "");
1749 }
1750
1751 /// Verify this signature against a message and public key with context.
1752 /// Returns an error if the signature is invalid.
1753 /// The context parameter is an optional context string (max 255 bytes).
1754 pub fn verifyWithContext(
1755 sig: Signature,
1756 msg: []const u8,
1757 public_key: PublicKey,
1758 context: []const u8,
1759 ) VerifyError!void {
1760 if (context.len > 255) {
1761 return error.SignatureVerificationFailed;
1762 }
1763
1764 var h = sha3.Shake256.init(.{});
1765 h.update(&public_key.tr);
1766 h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1767 h.update(&[_]u8{@intCast(context.len)});
1768 if (context.len > 0) {
1769 h.update(context);
1770 }
1771 h.update(msg);
1772 var mu: [64]u8 = undefined;
1773 h.squeeze(&mu);
1774
1775 const z_hat = sig.z.ntt();
1776 const Az = public_key.A.mulVecHat(z_hat);
1777
1778 // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing)
1779 var Az2dct1 = public_key.t1.mulBy2toD();
1780 Az2dct1 = Az2dct1.ntt();
1781 const c_poly = sampleInBall(p.tau, &sig.c_tilde);
1782 const c_hat = c_poly.ntt();
1783 for (0..p.k) |i| {
1784 Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat);
1785 }
1786 Az2dct1 = Az.sub(Az2dct1);
1787 Az2dct1 = Az2dct1.reduceLe2Q();
1788 Az2dct1 = Az2dct1.invNTT();
1789 Az2dct1 = Az2dct1.normalizeAssumingLe2Q();
1790
1791 // Apply hints to recover high bits w1'
                const w1_prime = Az2dct1.useHint(sig.hint, p.gamma2);
1793 var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
1794 w1_prime.packW1(p.gamma1_bits, &w1_packed);
1795
1796 const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed });
1797
1798 if (!mem.eql(u8, &c_prime, &sig.c_tilde)) {
1799 return error.SignatureVerificationFailed;
1800 }
1801 }
1802
1803 /// Create a Verifier for incrementally verifying a signature.
1804 pub fn verifier(self: Signature, public_key: PublicKey) !Verifier {
1805 return self.verifierWithContext(public_key, "");
1806 }
1807
1808 /// Create a Verifier for incrementally verifying a signature with context.
1809 /// The context parameter is an optional context string (max 255 bytes).
1810 pub fn verifierWithContext(self: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier {
1811 return Verifier.init(self, public_key, context);
1812 }
1813 };
1814
1815 /// A Signer is used to incrementally compute a signature over a streamed message.
1816 /// It can be obtained from a `SecretKey` or `KeyPair`, using the `signer()` function.
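///
/// Example (illustrative sketch; assumes a key pair `kp` and a message that arrives in chunks):
///
/// ```zig
/// var st = try kp.signer(null); // null noise selects deterministic signing
/// st.update(chunk1);
/// st.update(chunk2);
/// const sig = st.finalize();
/// ```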
1817 pub const Signer = struct {
1818 h: sha3.Shake256, // For computing μ = CRH(tr || msg)
1819 secret_key: *const SecretKey,
1820 rnd: [32]u8,
1821
1822 /// Initialize a new Signer.
1823 /// The noise parameter can be null for deterministic signatures,
1824 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1825 /// The context parameter is an optional context string (max 255 bytes).
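///
/// For hedged signing, fresh randomness can be supplied per signature, e.g.
/// (illustrative sketch; assumes `crypto.random` is available on the target):
///
/// ```zig
/// var noise: [noise_length]u8 = undefined;
/// crypto.random.bytes(&noise);
/// var st = try Signer.init(&secret_key, noise, "");
/// ```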
1826 pub fn init(secret_key: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
1827 if (context.len > 255) {
1828 return error.ContextTooLong;
1829 }
1830
1831 var h = sha3.Shake256.init(.{});
1832 h.update(&secret_key.tr);
1833 h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1834 h.update(&[_]u8{@intCast(context.len)});
1835 if (context.len > 0) {
1836 h.update(context);
1837 }
1838
1839 return Signer{
1840 .h = h,
1841 .secret_key = secret_key,
1842 .rnd = noise orelse .{0} ** 32,
1843 };
1844 }
1845
1846 /// Add new data to the message being signed.
1847 pub fn update(self: *Signer, data: []const u8) void {
1848 self.h.update(data);
1849 }
1850
1851 /// Compute a signature over the entire message.
1852 pub fn finalize(self: *Signer) Signature {
1853 var mu: [64]u8 = undefined;
1854 self.h.squeeze(&mu);
1855
1856 const rho_prime = crh(64, .{ &self.secret_key.key, &self.rnd, &mu });
1857
1858 var sig: Signature = undefined;
1859 var y_nonce: u16 = 0;
1860
1861 // Rejection sampling loop (FIPS 204, ML-DSA.Sign_internal)
1862 var attempt: u32 = 0;
1863 while (true) {
1864 attempt += 1;
1865 if (attempt >= 576) { // (6/7)⁵⁷⁶ < 2⁻¹²⁸
1866 @branchHint(.unlikely);
1867 unreachable;
1868 }
1869
1870 const y = PolyVecL.deriveUniformLeGamma1(p.gamma1_bits, &rho_prime, y_nonce);
1871 y_nonce += @intCast(p.l);
1872
1873 const y_hat = y.ntt();
1874 var w = self.secret_key.A.mulVec(y_hat);
1875
1876 w = w.normalize();
1877 var w0: PolyVecK = undefined;
1878 const w1 = w.decomposeVec(p.gamma2, &w0);
1879 var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
1880 w1.packW1(p.gamma1_bits, &w1_packed);
1881
1882 sig.c_tilde = crh(p.ctilde_size, .{ &mu, &w1_packed });
1883
1884 const c_poly = sampleInBall(p.tau, &sig.c_tilde);
1885 const c_hat = c_poly.ntt();
1886
1887 // Rejection check: ensure masking is effective
1888 var w0mcs2: PolyVecK = undefined;
1889 for (0..p.k) |i| {
1890 w0mcs2.ps[i] = c_hat.mulHat(self.secret_key.s2_hat.ps[i]);
1891 w0mcs2.ps[i] = w0mcs2.ps[i].invNTT();
1892 }
1893 w0mcs2 = w0.sub(w0mcs2);
1894 w0mcs2 = w0mcs2.normalize();
1895
1896 if (w0mcs2.exceeds(p.gamma2 - beta)) {
1897 continue;
1898 }
1899
1900 // Compute response z = y + c·s1
1901 for (0..p.l) |i| {
1902 sig.z.ps[i] = c_hat.mulHat(self.secret_key.s1_hat.ps[i]);
1903 sig.z.ps[i] = sig.z.ps[i].invNTT();
1904 }
1905 sig.z = sig.z.add(y);
1906 sig.z = sig.z.normalize();
1907
1908 if (sig.z.exceeds(gamma1 - beta)) {
1909 continue;
1910 }
1911
1912 var ct0: PolyVecK = undefined;
1913 for (0..p.k) |i| {
1914 ct0.ps[i] = c_hat.mulHat(self.secret_key.t0_hat.ps[i]);
1915 ct0.ps[i] = ct0.ps[i].invNTT();
1916 }
1917 ct0 = ct0.reduceLe2Q();
1918 ct0 = ct0.normalize();
1919
1920 if (ct0.exceeds(p.gamma2)) {
1921 continue;
1922 }
1923
1924 // Generate hints for verification
1925 var w0mcs2pct0 = w0mcs2.add(ct0);
1926 w0mcs2pct0 = w0mcs2pct0.reduceLe2Q();
1927 w0mcs2pct0 = w0mcs2pct0.normalizeAssumingLe2Q();
1928 const hint_result = PolyVecK.makeHintVec(w0mcs2pct0, w1, p.gamma2);
1929 if (hint_result.pop > p.omega) {
1930 continue;
1931 }
1932 sig.hint = hint_result.hint;
1933
1934 return sig;
1935 }
1936 }
1937 };
1938
1939 /// A Verifier is used to incrementally verify a signature over a streamed message.
1940 /// It can be obtained from a `Signature`, using the `verifier()` function.
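///
/// Example (illustrative sketch; mirrors the streaming `Signer` usage, assuming a
/// signature `sig`, a key pair `kp`, and message chunks):
///
/// ```zig
/// var vf = try sig.verifier(kp.public_key);
/// vf.update(chunk1);
/// vf.update(chunk2);
/// try vf.verify();
/// ```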
1941 pub const Verifier = struct {
1942 h: sha3.Shake256, // For computing μ = CRH(tr || msg)
1943 signature: Signature,
1944 public_key: PublicKey,
1945
1946 pub const InitError = EncodingError;
1947 pub const VerifyError = SignatureVerificationError;
1948
1949 /// Initialize a new Verifier.
1950 /// The context parameter is an optional context string (max 255 bytes).
1951 pub fn init(signature: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier {
1952 if (context.len > 255) {
1953 return error.ContextTooLong;
1954 }
1955
1956 var h = sha3.Shake256.init(.{});
1957 h.update(&public_key.tr);
1958 h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1959 h.update(&[_]u8{@intCast(context.len)}); // Context length
1960 if (context.len > 0) {
1961 h.update(context);
1962 }
1963
1964 return Verifier{
1965 .h = h,
1966 .signature = signature,
1967 .public_key = public_key,
1968 };
1969 }
1970
1971 /// Add new content to the message to be verified.
1972 pub fn update(self: *Verifier, data: []const u8) void {
1973 self.h.update(data);
1974 }
1975
1976 /// Verify that the signature is valid for the entire message.
1977 pub fn verify(self: *Verifier) SignatureVerificationError!void {
1978 var mu: [64]u8 = undefined;
1979 self.h.squeeze(&mu);
1980
1981 const z_hat = self.signature.z.ntt();
1982 const Az = self.public_key.A.mulVecHat(z_hat);
1983
1984 // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing)
1985 var Az2dct1 = self.public_key.t1.mulBy2toD();
1986 Az2dct1 = Az2dct1.ntt();
1987 const c_poly = sampleInBall(p.tau, &self.signature.c_tilde);
1988 const c_hat = c_poly.ntt();
1989 for (0..p.k) |i| {
1990 Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat);
1991 }
1992 Az2dct1 = Az.sub(Az2dct1);
1993 Az2dct1 = Az2dct1.reduceLe2Q();
1994 Az2dct1 = Az2dct1.invNTT();
1995 Az2dct1 = Az2dct1.normalizeAssumingLe2Q();
1996
1997 // Apply hints to recover high bits w1'
1998 var w1_prime = Az2dct1.useHint(self.signature.hint, p.gamma2);
1999 var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
2000 w1_prime.packW1(p.gamma1_bits, &w1_packed);
2001
2002 const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed });
2003
2004 if (!mem.eql(u8, &c_prime, &self.signature.c_tilde)) {
2005 return error.SignatureVerificationFailed;
2006 }
2007 }
2008 };
2009
2010 /// A key pair consisting of a secret key and its corresponding public key.
2011 pub const KeyPair = struct {
2012 /// Length (in bytes) of a seed required to create a key pair.
2013 pub const seed_length = Self.seed_length;
2014
2015 /// The public key component.
2016 public_key: PublicKey,
2017
2018 /// The secret key component.
2019 secret_key: SecretKey,
2020
2021 /// Generate a new random key pair.
2022 /// This uses the system's cryptographically secure random number generator.
2023 ///
2024 /// `crypto.random.bytes` must be supported by the target.
2025 pub fn generate() KeyPair {
2026 var seed: [Self.seed_length]u8 = undefined;
2027 crypto.random.bytes(&seed);
2028 return generateDeterministic(seed) catch unreachable;
2029 }
2030
2031 /// Generate a key pair deterministically from a seed.
2032 /// Use for testing or when reproducibility is required.
2033 /// The seed should be generated using a cryptographically secure random source.
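///
/// Example (illustrative sketch; the constant seed is for demonstration only):
///
/// ```zig
/// const seed = [_]u8{0x42} ** 32;
/// const kp1 = try MLDSA44.KeyPair.generateDeterministic(seed);
/// const kp2 = try MLDSA44.KeyPair.generateDeterministic(seed);
/// // kp1 and kp2 hold identical keys
/// ```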
2034 pub fn generateDeterministic(seed: [32]u8) !KeyPair {
2035 const keys = newKeyFromSeed(&seed);
2036 return .{
2037 .public_key = keys.pk,
2038 .secret_key = keys.sk,
2039 };
2040 }
2041
2042 /// Derive the public key from an existing secret key.
2043 /// This recomputes the public key components from the secret key.
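///
/// Example (illustrative sketch; assumes a previously serialized secret key `sk_bytes`):
///
/// ```zig
/// const sk = try MLDSA44.SecretKey.fromBytes(sk_bytes);
/// const kp = try MLDSA44.KeyPair.fromSecretKey(sk);
/// ```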
2044 pub fn fromSecretKey(sk: SecretKey) !KeyPair {
2045 var pk: PublicKey = undefined;
2046 pk.rho = sk.rho;
2047 pk.tr = sk.tr;
2048 pk.A = sk.A;
2049
2050 const t = computeT(sk.A, sk.s1_hat, sk.s2);
2051
2052 var t0: PolyVecK = undefined;
2053 pk.t1 = t.power2Round(&t0);
2054 pk.t1.packT1(&pk.t1_packed);
2055
2056 return .{
2057 .public_key = pk,
2058 .secret_key = sk,
2059 };
2060 }
2061
2062 /// Create a Signer for incrementally signing a message.
2063 /// The noise parameter can be null for deterministic signatures,
2064 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2065 pub fn signer(self: *const KeyPair, noise: ?[noise_length]u8) !Signer {
2066 return self.secret_key.signer(noise);
2067 }
2068
2069 /// Create a Signer for incrementally signing a message with context.
2070 /// The noise parameter can be null for deterministic signatures,
2071 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2072 /// The context parameter is an optional context string (max 255 bytes).
2073 pub fn signerWithContext(self: *const KeyPair, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
2074 return self.secret_key.signerWithContext(noise, context);
2075 }
2076
2077 /// Sign a message using this key pair.
2078 /// The noise parameter can be null for deterministic signatures,
2079 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2080 pub fn sign(
2081 kp: KeyPair,
2082 msg: []const u8,
2083 noise: ?[noise_length]u8,
2084 ) !Signature {
2085 return kp.signWithContext(msg, noise, "");
2086 }
2087
2088 /// Sign a message using this key pair with context.
2089 /// The noise parameter can be null for deterministic signatures,
2090 /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2091 /// The context parameter is an optional context string (max 255 bytes).
2092 pub fn signWithContext(
2093 kp: KeyPair,
2094 msg: []const u8,
2095 noise: ?[noise_length]u8,
2096 context: []const u8,
2097 ) ContextTooLongError!Signature {
2098 var st = try kp.signerWithContext(noise, context);
2099 st.update(msg);
2100 return st.finalize();
2101 }
2102 };
2103 };
2104}
2105
2106test "modular arithmetic" {
2107 // Test Montgomery reduction
2108 const x: u64 = 12345678;
2109 const y = montReduceLe2Q(x);
2110 try testing.expect(y < 2 * Q);
2111
2112 // Test modQ
2113 try testing.expectEqual(@as(u32, 0), modQ(Q));
2114 try testing.expectEqual(@as(u32, 1), modQ(Q + 1));
2115}
2116
2117test "polynomial operations" {
2118 var p1 = Poly.zero;
2119 p1.cs[0] = 1;
2120 p1.cs[1] = 2;
2121
2122 var p2 = Poly.zero;
2123 p2.cs[0] = 3;
2124 p2.cs[1] = 4;
2125
2126 const p3 = p1.add(p2);
2127 try testing.expectEqual(@as(u32, 4), p3.cs[0]);
2128 try testing.expectEqual(@as(u32, 6), p3.cs[1]);
2129}
2130
2131test "NTT and inverse NTT" {
2132 // Create a test polynomial in REGULAR FORM (not Montgomery)
2133 var p = Poly.zero;
2134 for (0..N) |i| {
2135 p.cs[i] = @intCast(i % Q);
2136 }
2137
2138 // Apply NTT then inverse NTT
2139 // According to Dilithium spec: NTT followed by invNTT multiplies by R
2140 // So result will be p * R (i.e., p in Montgomery form)
2141 var p_ntt = p.ntt();
2142
2143 // Reduce before invNTT (as Go test does)
2144 p_ntt = p_ntt.reduceLe2Q();
2145
2146 const p_restored = p_ntt.invNTT();
2147
2148 // Reduce and normalize
2149 const p_reduced = p_restored.reduceLe2Q();
2150 const p_norm = p_reduced.normalize();
2151
2152 // Check if we get p * R (which equals toMont(p))
2153 for (0..N) |i| {
2154 const original: u32 = @intCast(i % Q);
2155 const expected = toMont(original);
2156 const expected_norm = modQ(expected);
2157 try testing.expectEqual(expected_norm, p_norm.cs[i]);
2158 }
2159}
2160
2161test "parameter set instantiation" {
2162 // Just verify we can instantiate all three parameter sets
2163 const ml44 = MLDSA44;
2164 const ml65 = MLDSA65;
2165 const ml87 = MLDSA87;
2166
2167 try testing.expectEqualStrings("ML-DSA-44", ml44.name);
2168 try testing.expectEqualStrings("ML-DSA-65", ml65.name);
2169 try testing.expectEqualStrings("ML-DSA-87", ml87.name);
2170}
2171
2172test "compare zetas with Go implementation" {
2173 // First 16 zetas from Go implementation (in Montgomery form)
2174 const go_zetas = [16]u32{
2175 4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169,
2176 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562,
2177 3111497, 2680103,
2178 };
2179
2180 // Compare our computed zetas with Go's
2181 for (0..16) |i| {
2182 try testing.expectEqual(go_zetas[i], zetas[i]);
2183 }
2184}
2185
2186test "NTT with simple polynomial" {
2187 // Test with a very simple polynomial: just one coefficient set to 1 in regular form
2188 var p = Poly.zero;
2189 p.cs[0] = 1;
2190
2191 var p_ntt = p.ntt();
2192
2193 // Reduce before invNTT (as Go test does)
2194 p_ntt = p_ntt.reduceLe2Q();
2195
2196 const p_restored = p_ntt.invNTT();
2197
2198 // Result should be 1 * R = toMont(1) in Montgomery form
2199 const p_reduced = p_restored.reduceLe2Q();
2200 const p_norm = p_reduced.normalize();
2201
2202 const expected = modQ(toMont(1));
2203 try testing.expectEqual(expected, p_norm.cs[0]);
2204
2205 // All other coefficients should be 0 * R = 0
2206 for (1..N) |i| {
2207 try testing.expectEqual(@as(u32, 0), p_norm.cs[i]);
2208 }
2209}
2210
2211test "Montgomery reduction correctness" {
2212 // Test that Montgomery reduction works correctly
2213 // montReduceLe2Q(a * b * R) = a * b mod q (where a, b are in Montgomery form)
2214
2215 const x: u32 = 12345;
2216 const y: u32 = 67890;
2217
2218 // Convert to Montgomery form
2219 const x_mont = toMont(x);
2220 const y_mont = toMont(y);
2221
2222 // Multiply in Montgomery form
2223 const product_mont = montReduceLe2Q(@as(u64, x_mont) * @as(u64, y_mont));
2224
2225 // Convert back from Montgomery form
2226 const product = montReduceLe2Q(@as(u64, product_mont));
2227
2228 // Direct multiplication mod q
2229 const expected = modQ(@as(u32, @intCast((@as(u64, x) * @as(u64, y)) % Q)));
2230
2231 try testing.expectEqual(expected, modQ(product));
2232}
2233
2235
2236test "compare inv_zetas with Go implementation" {
2237 // First 16 inv_zetas from Go implementation
2238 const go_inv_zetas = [16]u32{
2239 6403635, 846154, 6979993, 4442679, 1362209, 48306, 4460757,
2240 554416, 3545687, 6767575, 976891, 8196974, 2286327, 420899,
2241 2235985, 2939036,
2242 };
2243
2244 // Compare our computed inv_zetas with Go's
2245 for (0..16) |i| {
2246 if (inv_zetas[i] != go_inv_zetas[i]) {
2247 std.debug.print("Mismatch at inv_zetas[{d}]: got {d}, expected {d}\n", .{ i, inv_zetas[i], go_inv_zetas[i] });
2248 }
2249 try testing.expectEqual(go_inv_zetas[i], inv_zetas[i]);
2250 }
2251}
2252
2253test "power2Round correctness" {
2254 // Test that power2Round correctly splits values
2255 // For all a in [0, Q), we should have a = a1*2^D + a0
2256 // where -2^(D-1) < a0 <= 2^(D-1)
2257
2258 // Test a few specific values
2259 const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345, 8380416 };
2260
2261 for (test_values) |a| {
2262 if (a >= Q) continue;
2263
2264 const result = power2Round(a);
2265 const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
2266 const a1 = result.a1;
2267
2268 // Check reconstruction: a = a1*2^D + a0
2269 const reconstructed = @as(i32, @bitCast(a1 << D)) + a0;
2270 try testing.expectEqual(@as(i32, @bitCast(a)), reconstructed);
2271
2272 // Check a0 bounds: -2^(D-1) < a0 <= 2^(D-1)
2273 const bound: i32 = 1 << (D - 1);
2274 try testing.expect(a0 > -bound and a0 <= bound);
2275 }
2276}
2277
2278test "decompose correctness for ML-DSA-65" {
2279 // Test decompose with gamma2 = 95232 (ML-DSA-44)
2280 const gamma2 = 95232;
2281 const alpha = 2 * gamma2;
2282
2283 const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };
2284
2285 for (test_values) |a| {
2286 if (a >= Q) continue;
2287
2288 const result = decompose(a, gamma2);
2289 const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
2290 const a1 = result.a1;
2291
2292 // Check reconstruction: a = a1*alpha + a0 (mod Q)
2293 var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
2294 reconstructed = @mod(reconstructed, @as(i64, Q));
2295 try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);
2296
2297 // Check a0 bounds (approximately)
2298 const bound: i32 = @intCast(alpha / 2);
2299 try testing.expect(@abs(a0) <= bound);
2300 }
2301}
2302
2303test "decompose correctness for ML-DSA-87" {
2304 // Test decompose with gamma2 = 261888 (ML-DSA-65 and ML-DSA-87)
2305 const gamma2 = 261888;
2306 const alpha = 2 * gamma2;
2307
2308 const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };
2309
2310 for (test_values) |a| {
2311 if (a >= Q) continue;
2312
2313 const result = decompose(a, gamma2);
2314 const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
2315 const a1 = result.a1;
2316
2317 // Check reconstruction: a = a1*alpha + a0 (mod Q)
2318 var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
2319 reconstructed = @mod(reconstructed, @as(i64, Q));
2320 try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);
2321
2322 // Check a0 bounds (approximately)
2323 const bound: i32 = @intCast(alpha / 2);
2324 try testing.expect(@abs(a0) <= bound);
2325 }
2326}
2327
2328test "polyDeriveUniform deterministic" {
2329 // Test that polyDeriveUniform produces deterministic results
2330 const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31;
2331 const nonce: u16 = 0;
2332
2333 const p1 = polyDeriveUniform(&seed, nonce);
2334 const p2 = polyDeriveUniform(&seed, nonce);
2335
2336 // Should be identical
2337 for (0..N) |i| {
2338 try testing.expectEqual(p1.cs[i], p2.cs[i]);
2339 }
2340
2341 // All coefficients should be in [0, Q)
2342 for (0..N) |i| {
2343 try testing.expect(p1.cs[i] < Q);
2344 }
2345}
2346
2347test "polyDeriveUniform different nonces" {
2348 // Test that different nonces produce different polynomials
2349 const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31;
2350
2351 const p1 = polyDeriveUniform(&seed, 0);
2352 const p2 = polyDeriveUniform(&seed, 1);
2353
2354 // Should be different
2355 var different = false;
2356 for (0..N) |i| {
2357 if (p1.cs[i] != p2.cs[i]) {
2358 different = true;
2359 break;
2360 }
2361 }
2362 try testing.expect(different);
2363}
2364
2365test "expandS with eta=2" {
2366 // Test eta=2 sampling
2367 const seed: [64]u8 = .{0x02} ++ .{0x00} ** 63;
2368 const nonce: u16 = 0;
2369
2370 const p = expandS(2, &seed, nonce);
2371
2372 // The function returns coefficients as Q + eta - t, where t is in [0, 2*eta],
2373 // so all coefficients lie in [Q-eta, Q+eta]
2375 for (0..N) |i| {
2376 const c = p.cs[i];
2377 // Check that c is in [Q-2, Q+2]
2378 try testing.expect(c >= Q - 2 and c <= Q + 2);
2379 }
2380}
2381
2382test "expandS with eta=4" {
2383 // Test eta=4 sampling
2384 const seed: [64]u8 = .{0x03} ++ .{0x00} ** 63;
2385 const nonce: u16 = 0;
2386
2387 const p = expandS(4, &seed, nonce);
2388
2389 // All coefficients should be in [Q-eta, Q+eta]
2390 for (0..N) |i| {
2391 const c = p.cs[i];
2392 // Check bounds (coefficients are around Q ± eta)
2393 const diff = if (c >= Q) c - Q else Q - c;
2394 try testing.expect(diff <= 4);
2395 }
2396}
2397
2398test "sampleInBall has correct weight" {
2399 // Test that ball polynomial has exactly tau non-zero coefficients
2400 const tau = 39; // From ML-DSA-44
2401 const seed: [32]u8 = .{0x04} ++ .{0x00} ** 31;
2402
2403 const p = sampleInBall(tau, &seed);
2404
2405 // Count non-zero coefficients
2406 var count: u32 = 0;
2407 for (0..N) |i| {
2408 if (p.cs[i] != 0) {
2409 count += 1;
2410 // Non-zero coefficients should be 1 or Q-1
2411 try testing.expect(p.cs[i] == 1 or p.cs[i] == Q - 1);
2412 }
2413 }
2414
2415 try testing.expectEqual(tau, count);
2416}
2417
2418test "sampleInBall deterministic" {
2419 // Test that ball sampling is deterministic
2420 const tau = 49; // From ML-DSA-65
2421 const seed: [32]u8 = .{0x05} ++ .{0x00} ** 31;
2422
2423 const p1 = sampleInBall(tau, &seed);
2424 const p2 = sampleInBall(tau, &seed);
2425
2426 // Should be identical
2427 for (0..N) |i| {
2428 try testing.expectEqual(p1.cs[i], p2.cs[i]);
2429 }
2430}
2431
2432test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=2" {
2433 // Test packing and unpacking for eta=2
2434 const eta = 2;
2435
2436 // Create a test polynomial with coefficients in [Q-eta, Q+eta]
2437 var p = Poly.zero;
2438 for (0..N) |i| {
2439 // Use various values in range
2440 const val = @as(u32, @intCast(i % 5)); // 0, 1, 2, 3, 4
2441 p.cs[i] = Q + eta - val;
2442 }
2443
2444 // Pack it
2445 var buf: [96]u8 = undefined; // eta=2: 3 bits per coeff = 96 bytes
2446 polyPackLeqEta(p, eta, &buf);
2447
2448 // Unpack it
2449 const p2 = polyUnpackLeqEta(eta, &buf);
2450
2451 // Should be identical
2452 for (0..N) |i| {
2453 try testing.expectEqual(p.cs[i], p2.cs[i]);
2454 }
2455}
2456
2457test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=4" {
2458 // Test packing and unpacking for eta=4
2459 const eta = 4;
2460
2461 // Create a test polynomial with coefficients in [Q-eta, Q+eta]
2462 var p = Poly.zero;
2463 for (0..N) |i| {
2464 // Use various values in range
2465 const val = @as(u32, @intCast(i % 9)); // 0, 1, 2, ..., 8
2466 p.cs[i] = Q + eta - val;
2467 }
2468
2469 // Pack it
2470 var buf: [128]u8 = undefined; // eta=4: 4 bits per coeff = 128 bytes
2471 polyPackLeqEta(p, eta, &buf);
2472
2473 // Unpack it
2474 const p2 = polyUnpackLeqEta(eta, &buf);
2475
2476 // Should be identical
2477 for (0..N) |i| {
2478 try testing.expectEqual(p.cs[i], p2.cs[i]);
2479 }
2480}
2481
2482test "polyPackT1 / polyUnpackT1 roundtrip" {
2483 // Create a test polynomial with coefficients < 1024
2484 var p = Poly.zero;
2485 for (0..N) |i| {
2486 p.cs[i] = @intCast(i % 1024);
2487 }
2488
2489 // Pack it
2490 var buf: [320]u8 = undefined; // (256 * 10) / 8 = 320 bytes
2491 polyPackT1(p, &buf);
2492
2493 // Unpack it
2494 const p2 = polyUnpackT1(&buf);
2495
2496 // Should be identical
2497 for (0..N) |i| {
2498 try testing.expectEqual(p.cs[i], p2.cs[i]);
2499 }
2500}
2501
2502test "polyPackT0 / polyUnpackT0 roundtrip" {
2503 // Create a test polynomial with coefficients in (Q-2^12, Q+2^12]
2504 // This is the range (-2^12, 2^12] represented as unsigned around Q
2505 const bound = 1 << 12; // 2^(D-1) where D=13
2506 var p = Poly.zero;
2507 for (0..N) |i| {
2508 // Cycle through valid range for T0
2509 // Values should be Q + offset where offset is in (-bound, bound]
2510 const cycle_val = @as(i32, @intCast(i % (2 * bound))); // 0 to 2*bound-1
2511 const offset = cycle_val - bound + 1; // (-bound+1) to bound
2512 p.cs[i] = @as(u32, @intCast(@as(i32, Q) + offset));
2513 }
2514
2515 // Pack it
2516 var buf: [416]u8 = undefined; // (256 * 13) / 8 = 416 bytes
2517 polyPackT0(p, &buf);
2518
2519 // Unpack it
2520 const p2 = polyUnpackT0(&buf);
2521
2522 // Should be identical
2523 for (0..N) |i| {
2524 try testing.expectEqual(p.cs[i], p2.cs[i]);
2525 }
2526}
2527
2528test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=17" {
2529 const gamma1_bits = 17;
2530 const gamma1: u32 = @as(u32, 1) << gamma1_bits;
2531
2532 // Create a test polynomial with coefficients in (-gamma1, gamma1]
2533 // Normalized: [0, gamma1] ∪ (Q-gamma1, Q)
2534 var p = Poly.zero;
2535 for (0..N) |i| {
2536 if (i % 2 == 0) {
2537 // Positive values: [0, gamma1]
2538 p.cs[i] = @intCast((i / 2) % (gamma1 + 1));
2539 } else {
2540 // Negative values: (Q-gamma1, Q)
2541 const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1);
2542 p.cs[i] = Q - neg_val;
2543 }
2544 }
2545
2546 // Pack it
2547 var buf: [576]u8 = undefined; // (256 * 18) / 8 = 576 bytes
2548 polyPackLeGamma1(p, gamma1_bits, &buf);
2549
2550 // Unpack it
2551 const p2 = polyUnpackLeGamma1(gamma1_bits, &buf);
2552
2553 // Should be identical
2554 for (0..N) |i| {
2555 try testing.expectEqual(p.cs[i], p2.cs[i]);
2556 }
2557}
2558
2559test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=19" {
2560 const gamma1_bits = 19;
2561 const gamma1: u32 = @as(u32, 1) << gamma1_bits;
2562
2563 // Create a test polynomial with coefficients in (-gamma1, gamma1]
2564 var p = Poly.zero;
2565 for (0..N) |i| {
2566 if (i % 2 == 0) {
2567 // Positive values: [0, gamma1]
2568 p.cs[i] = @intCast((i / 2) % (gamma1 + 1));
2569 } else {
2570 // Negative values: (Q-gamma1, Q)
2571 const neg_val: u32 = @intCast(((i / 2) % gamma1) + 1);
2572 p.cs[i] = Q - neg_val;
2573 }
2574 }
2575
2576 // Pack it
2577 var buf: [640]u8 = undefined; // (256 * 20) / 8 = 640 bytes
2578 polyPackLeGamma1(p, gamma1_bits, &buf);
2579
2580 // Unpack it
2581 const p2 = polyUnpackLeGamma1(gamma1_bits, &buf);
2582
2583 // Should be identical
2584 for (0..N) |i| {
2585 try testing.expectEqual(p.cs[i], p2.cs[i]);
2586 }
2587}
2588
2589test "polyPackW1 for gamma1_bits=17" {
2590 const gamma1_bits = 17;
2591
2592 // Create a test polynomial with small coefficients (w1 values < 64)
2593 var p = Poly.zero;
2594 for (0..N) |i| {
2595 p.cs[i] = @intCast(i % 64); // 6-bit values
2596 }
2597
2598 // Pack it
2599 var buf: [192]u8 = undefined; // (256 * 6) / 8 = 192 bytes
2600 polyPackW1(p, gamma1_bits, &buf);
2601
2602 // Sanity check: the packed buffer should contain some non-zero bytes
2604 var non_zero = false;
2605 for (buf) |b| {
2606 if (b != 0) {
2607 non_zero = true;
2608 break;
2609 }
2610 }
2611 try testing.expect(non_zero);
2612}
2613
2614test "polyPackW1 for gamma1_bits=19" {
2615 const gamma1_bits = 19;
2616
2617 // Create a test polynomial with small coefficients (w1 values < 16)
2618 var p = Poly.zero;
2619 for (0..N) |i| {
2620 p.cs[i] = @intCast(i % 16); // 4-bit values
2621 }
2622
2623 // Pack it
2624 var buf: [128]u8 = undefined; // (256 * 4) / 8 = 128 bytes
2625 polyPackW1(p, gamma1_bits, &buf);
2626
2627 // Verify basic properties
2628 var non_zero = false;
2629 for (buf) |b| {
2630 if (b != 0) {
2631 non_zero = true;
2632 break;
2633 }
2634 }
2635 try testing.expect(non_zero);
2636}
2637
2638test "makeHint and useHint correctness for gamma2=261888" {
2639 // Test for ML-DSA-65 and ML-DSA-87
2640 const gamma2: u32 = 261888;
2641
2642 // Test a selection of values to verify the hint mechanism works
2643 const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 };
2644
2645 for (test_values) |w| {
2646 // Decompose w to get w0 and w1
2647 const decomp = decompose(w, gamma2);
2648 const w0_plus_q = decomp.a0_plus_q;
2649 const w1 = decomp.a1;
2650
2651 // Test with various small perturbations f in [0, gamma2]
2652 const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 };
2653
2654 for (perturbations) |f| {
2655 // Test f (positive perturbation)
2656 const z0_pos = (w0_plus_q +% Q -% f) % Q;
2657 const hint_pos = makeHint(z0_pos, w1, gamma2);
2658 const w_perturbed_pos = (w +% Q -% f) % Q;
2659 const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2);
2660 try testing.expectEqual(w1, w1_recovered_pos);
2661
2662 // Test -f (negative perturbation)
2663 if (f > 0) {
2664 const z0_neg = (w0_plus_q +% f) % Q;
2665 const hint_neg = makeHint(z0_neg, w1, gamma2);
2666 const w_perturbed_neg = (w +% f) % Q;
2667 const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2);
2668 try testing.expectEqual(w1, w1_recovered_neg);
2669 }
2670 }
2671 }
2672}
2673
2674test "makeHint and useHint correctness for gamma2=95232" {
2675 // Test for ML-DSA-44
2676 const gamma2: u32 = 95232;
2677
2678 // Test a selection of values to verify the hint mechanism works
2679 const test_values = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 };
2680
2681 for (test_values) |w| {
2682 // Decompose w to get w0 and w1
2683 const decomp = decompose(w, gamma2);
2684 const w0_plus_q = decomp.a0_plus_q;
2685 const w1 = decomp.a1;
2686
2687 // Test with various small perturbations f in [0, gamma2]
2688 const perturbations = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 };
2689
2690 for (perturbations) |f| {
2691 // Test f (positive perturbation)
2692 const z0_pos = (w0_plus_q +% Q -% f) % Q;
2693 const hint_pos = makeHint(z0_pos, w1, gamma2);
2694 const w_perturbed_pos = (w +% Q -% f) % Q;
2695 const w1_recovered_pos = useHint(w_perturbed_pos, hint_pos, gamma2);
2696 try testing.expectEqual(w1, w1_recovered_pos);
2697
2698 // Test -f (negative perturbation)
2699 if (f > 0) {
2700 const z0_neg = (w0_plus_q +% f) % Q;
2701 const hint_neg = makeHint(z0_neg, w1, gamma2);
2702 const w_perturbed_neg = (w +% f) % Q;
2703 const w1_recovered_neg = useHint(w_perturbed_neg, hint_neg, gamma2);
2704 try testing.expectEqual(w1, w1_recovered_neg);
2705 }
2706 }
2707 }
2708}
2709
2710test "polyMakeHint basic functionality" {
2711 const gamma2: u32 = 261888;
2712
2713 // Create test polynomials
2714 var p0 = Poly.zero;
2715 var p1 = Poly.zero;
2716
2717 // Fill with test values
2718 for (0..N) |i| {
2719 p0.cs[i] = @intCast((i * 17) % Q);
2720 p1.cs[i] = @intCast((i * 3) % 16); // High bits are at most 15 for gamma2=261888
2721 }
2722
2723 // Make hints
2724 const result = polyMakeHint(p0, p1, gamma2);
2725 const hint = result.hint;
2726 const count = result.count;
2727
2728 // Verify that hints are binary
2729 for (0..N) |i| {
2730 try testing.expect(hint.cs[i] == 0 or hint.cs[i] == 1);
2731 }
2732
2733 // Verify that count matches the number of 1s in hint
2734 var actual_count: u32 = 0;
2735 for (0..N) |i| {
2736 actual_count += hint.cs[i];
2737 }
2738 try testing.expectEqual(count, actual_count);
2739}
2740
2741test "polyUseHint reconstruction" {
2742 const gamma2: u32 = 261888;
2743
2744 // Create a test polynomial q
2745 var q = Poly.zero;
2746 for (0..N) |i| {
2747 q.cs[i] = @intCast((i * 123) % Q);
2748 }
2749
2750 // Decompose q to get high and low bits
2751 var q0_plus_q_array: [N]u32 = undefined;
2752 var q1_array: [N]u32 = undefined;
2753 for (0..N) |i| {
2754 const decomp = decompose(q.cs[i], gamma2);
2755 q0_plus_q_array[i] = decomp.a0_plus_q;
2756 q1_array[i] = decomp.a1;
2757 }
2758
2759 const q0_plus_q = Poly{ .cs = q0_plus_q_array };
2760 const q1 = Poly{ .cs = q1_array };
2761
2762 // Create hints (mostly zero here, since the low bits are passed in unperturbed)
2763 const hint_result = polyMakeHint(q0_plus_q, q1, gamma2);
2764 const hint = hint_result.hint;
2765
2766 // Use hints to recover high bits
2767 const recovered = polyUseHint(q, hint, gamma2);
2768
2769 // Recovered should match original high bits q1
2770 for (0..N) |i| {
2771 try testing.expectEqual(q1.cs[i], recovered.cs[i]);
2772 }
2773}
2774
2775test "hint roundtrip with perturbation" {
2776 const gamma2: u32 = 261888;
2777
2778 // Create a test polynomial w
2779 var w = Poly.zero;
2780 for (0..N) |i| {
2781 w.cs[i] = @intCast((i * 7919) % Q);
2782 }
2783
2784 // Decompose w to get w0 and w1
2785 var w0_plus_q = Poly.zero;
2786 var w1 = Poly.zero;
2787 for (0..N) |i| {
2788 const decomp = decompose(w.cs[i], gamma2);
2789 w0_plus_q.cs[i] = decomp.a0_plus_q;
2790 w1.cs[i] = decomp.a1;
2791 }
2792
2793 // Apply a small perturbation
2794 var f = Poly.zero;
2795 for (0..N) |i| {
2796 // Small perturbation (magnitude < 1000, well within ±gamma2)
2797 const f_val = @as(u32, @intCast(i % 1000));
2798 f.cs[i] = if (i % 2 == 0) f_val else Q -% f_val;
2799 }
2800
2801 // Compute w' = w - f and z0 = w0 - f
2802 var w_prime = Poly.zero;
2803 var z0 = Poly.zero;
2804 for (0..N) |i| {
2805 w_prime.cs[i] = (w.cs[i] +% Q -% f.cs[i]) % Q;
2806 z0.cs[i] = (w0_plus_q.cs[i] +% Q -% f.cs[i]) % Q;
2807 }
2808
2809 // Make hints
2810 const hint_result = polyMakeHint(z0, w1, gamma2);
2811 const hint = hint_result.hint;
2812
2813 // Use hints to recover w1 from w_prime
2814 const w1_recovered = polyUseHint(w_prime, hint, gamma2);
2815
2816 // Verify that we recovered the original high bits
2817 for (0..N) |i| {
2818 try testing.expectEqual(w1.cs[i], w1_recovered.cs[i]);
2819 }
2820}
2821
2822// Parameterized test helper for key generation
2823
2824fn testKeyGenerationBasic(comptime MlDsa: type, seed: [32]u8) !void {
2825 const result = MlDsa.newKeyFromSeed(&seed);
2826 const pk = result.pk;
2827 const sk = result.sk;
2828
2829 // Basic sanity checks
2830 try testing.expect(pk.rho.len == 32);
2831 try testing.expect(sk.rho.len == 32);
2832 try testing.expectEqualSlices(u8, &pk.rho, &sk.rho);
2833
2834 // Verify tr matches between pk and sk
2835 try testing.expectEqualSlices(u8, &pk.tr, &sk.tr);
2836
2837 // Test toBytes/fromBytes round-trip for public key
2838 const pk_bytes = pk.toBytes();
2839 const pk2 = try MlDsa.PublicKey.fromBytes(pk_bytes);
2840 try testing.expectEqualSlices(u8, &pk.rho, &pk2.rho);
2841 try testing.expectEqualSlices(u8, &pk.tr, &pk2.tr);
2842
2843 // Test toBytes/fromBytes round-trip for secret key
2844 const sk_bytes = sk.toBytes();
2845 const sk2 = try MlDsa.SecretKey.fromBytes(sk_bytes);
2846 try testing.expectEqualSlices(u8, &sk.rho, &sk2.rho);
2847 try testing.expectEqualSlices(u8, &sk.key, &sk2.key);
2848 try testing.expectEqualSlices(u8, &sk.tr, &sk2.tr);
2849}
2850
2851test "Key generation basic - all variants" {
2852 inline for (.{
2853 .{ .variant = MLDSA44, .seed_byte = 0x44 },
2854 .{ .variant = MLDSA65, .seed_byte = 0x65 },
2855 .{ .variant = MLDSA87, .seed_byte = 0x87 },
2856 }) |config| {
2857 const seed = [_]u8{config.seed_byte} ** 32;
2858 try testKeyGenerationBasic(config.variant, seed);
2859 }
2860}
2861
2862test "Key generation determinism" {
2863 const seed = [_]u8{ 0x12, 0x34, 0x56, 0x78 } ++ [_]u8{0xAB} ** 28;
2864
2865 // Generate two key pairs from the same seed
2866 const result1 = MLDSA44.newKeyFromSeed(&seed);
2867 const result2 = MLDSA44.newKeyFromSeed(&seed);
2868
2869 // They should be identical
2870 const pk_bytes1 = result1.pk.toBytes();
2871 const pk_bytes2 = result2.pk.toBytes();
2872 try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2);
2873
2874 const sk_bytes1 = result1.sk.toBytes();
2875 const sk_bytes2 = result2.sk.toBytes();
2876 try testing.expectEqualSlices(u8, &sk_bytes1, &sk_bytes2);
2877}
2878
2879test "Private key can compute public key" {
2880 const seed = [_]u8{0xFF} ** 32;
2881 const result = MLDSA44.newKeyFromSeed(&seed);
2882 const pk = result.pk;
2883 const sk = result.sk;
2884
2885 // Compute public key from private key
2886 const pk_from_sk = sk.public();
2887
2888 // Pack both public keys and compare
2889 const pk_bytes1 = pk.toBytes();
2890 const pk_bytes2 = pk_from_sk.toBytes();
2891
2892 try testing.expectEqualSlices(u8, &pk_bytes1, &pk_bytes2);
2893}
2894
2895// Parameterized test helper for sign and verify
2896fn testSignAndVerify(comptime MlDsa: type, seed: [32]u8, message: []const u8) !void {
2897 const result = MlDsa.newKeyFromSeed(&seed);
2898 const kp = try MlDsa.KeyPair.fromSecretKey(result.sk);
2899
2900 // Sign the message
2901 const sig = try kp.sign(message, null);
2902
2903 // Verify the signature
2904 try sig.verify(message, kp.public_key);
2905}
2906
2907test "Sign and verify - all variants" {
2908 inline for (.{
2909 .{ .variant = MLDSA44, .seed_byte = 0x44, .message = "Hello, ML-DSA-44!" },
2910 .{ .variant = MLDSA65, .seed_byte = 0x65, .message = "Hello, ML-DSA-65!" },
2911 .{ .variant = MLDSA87, .seed_byte = 0x87, .message = "Hello, ML-DSA-87!" },
2912 }) |config| {
2913 const seed = [_]u8{config.seed_byte} ** 32;
2914 try testSignAndVerify(config.variant, seed, config.message);
2915 }
2916}
2917
2918test "Invalid signature rejection" {
2919 const seed = [_]u8{0x99} ** 32;
2920 const result = MLDSA44.newKeyFromSeed(&seed);
2921 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
2922
2923 const message = "Original message";
2924
2925 // Sign the message
2926 const sig = try kp.sign(message, null);
2927
2928 // Verify with wrong message should fail
2929 const wrong_message = "Modified message";
2930 try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_message, kp.public_key));
2931
2932 // Modify signature and verify should fail
2933 var corrupted_sig_bytes = sig.toBytes();
2934 corrupted_sig_bytes[0] ^= 0xFF;
2935 const corrupted_sig = try MLDSA44.Signature.fromBytes(corrupted_sig_bytes);
2936 try testing.expectError(error.SignatureVerificationFailed, corrupted_sig.verify(message, kp.public_key));
2937}
2938
2939test "Context string support" {
2940 const seed = [_]u8{0xAA} ** 32;
2941 const result = MLDSA44.newKeyFromSeed(&seed);
2942 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
2943
2944 const message = "Test message";
2945 const context1 = "context1";
2946 const context2 = "context2";
2947
2948 // Sign with context1
2949 const sig1 = try kp.signWithContext(message, null, context1);
2950
2951 // Verify with correct context should succeed
2952 try sig1.verifyWithContext(message, kp.public_key, context1);
2953
2954 // Verify with wrong context should fail
2955 try testing.expectError(error.SignatureVerificationFailed, sig1.verifyWithContext(message, kp.public_key, context2));
2956
2957 // Verify with empty context should fail
2958 try testing.expectError(error.SignatureVerificationFailed, sig1.verify(message, kp.public_key));
2959
2960 // Sign with empty context
2961 const sig2 = try kp.sign(message, null);
2962
2963 // Verify with empty context should succeed
2964 try sig2.verify(message, kp.public_key);
2965
2966 // Verify with non-empty context should fail
2967 try testing.expectError(error.SignatureVerificationFailed, sig2.verifyWithContext(message, kp.public_key, context1));
2968
2969 // Test maximum context length (255 bytes)
2970 const max_context = [_]u8{0xBB} ** 255;
2971 const sig3 = try kp.signWithContext(message, null, &max_context);
2972 try sig3.verifyWithContext(message, kp.public_key, &max_context);
2973
2974 // Test context too long (256 bytes should fail)
2975 const too_long_context = [_]u8{0xCC} ** 256;
2976 try testing.expectError(error.ContextTooLong, kp.signWithContext(message, null, &too_long_context));
2977}
2978
2979test "Context string with streaming API" {
2980 const seed = [_]u8{0xDD} ** 32;
2981 const result = MLDSA44.newKeyFromSeed(&seed);
2982 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
2983
2984 const context = "streaming-context";
2985 const message_part1 = "Hello, ";
2986 const message_part2 = "World!";
2987
2988 // Sign using streaming API with context
2989 var signer = try kp.signerWithContext(null, context);
2990 signer.update(message_part1);
2991 signer.update(message_part2);
2992 const sig = signer.finalize();
2993
2994 // Verify using streaming API with context
2995 var verifier = try sig.verifierWithContext(kp.public_key, context);
2996 verifier.update(message_part1);
2997 verifier.update(message_part2);
2998 try verifier.verify();
2999
3000 // Verify with wrong context should fail
3001 var verifier_wrong = try sig.verifierWithContext(kp.public_key, "wrong");
3002 verifier_wrong.update(message_part1);
3003 verifier_wrong.update(message_part2);
3004 try testing.expectError(error.SignatureVerificationFailed, verifier_wrong.verify());
3005}
3006
3007test "Signature determinism (same rnd)" {
3008 const seed = [_]u8{0x11} ** 32;
3009 const result = MLDSA44.newKeyFromSeed(&seed);
3010 const sk = result.sk;
3011
3012 const message = "Deterministic test";
3013 const rnd = [_]u8{0x22} ** 32;
3014
3015 // Sign twice with same randomness using streaming API
3016 var st1 = try sk.signer(rnd);
3017 st1.update(message);
3018 const sig1 = st1.finalize();
3019
3020 var st2 = try sk.signer(rnd);
3021 st2.update(message);
3022 const sig2 = st2.finalize();
3023
3024 // Signatures should be identical
3025 try testing.expectEqualSlices(u8, &sig1.toBytes(), &sig2.toBytes());
3026}
3027
3028test "Signature toBytes/fromBytes roundtrip" {
3029 const seed = [_]u8{0x33} ** 32;
3030 const result = MLDSA44.newKeyFromSeed(&seed);
3031 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
3032
3033 const message = "toBytes/fromBytes test";
3034
3035 // Sign the message
3036 const sig = try kp.sign(message, null);
3037 const sig_bytes = sig.toBytes();
3038
3039 // Unpack and repack
3040 const sig_reparsed = try MLDSA44.Signature.fromBytes(sig_bytes);
3041
3042 const repacked = sig_reparsed.toBytes();
3043
3044 // Should match original
3045 try testing.expectEqualSlices(u8, &sig_bytes, &repacked);
3046}
3047
3048test "Empty message signing" {
3049 const seed = [_]u8{0x44} ** 32;
3050 const result = MLDSA44.newKeyFromSeed(&seed);
3051 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
3052
3053 const message = "";
3054
3055 // Sign empty message
3056 const sig = try kp.sign(message, null);
3057
3058 // Verify should work
3059 try sig.verify(message, kp.public_key);
3060}
3061
3062test "Long message signing" {
3063 const seed = [_]u8{0x55} ** 32;
3064 const result = MLDSA44.newKeyFromSeed(&seed);
3065 const kp = try MLDSA44.KeyPair.fromSecretKey(result.sk);
3066
3067 // Create a long message (1KB)
3068 const long_message = [_]u8{0xAB} ** 1024;
3069
3070 // Sign long message
3071 const sig = try kp.sign(&long_message, null);
3072
3073 // Verify should work
3074 try sig.verify(&long_message, kp.public_key);
3075}
3076
3077// Helper function to decode hex string into bytes
3078fn hexToBytes(comptime hex: []const u8, out: []u8) !void {
3079 if (hex.len != out.len * 2) return error.InvalidLength;
3080
3081 var i: usize = 0;
3082 while (i < out.len) : (i += 1) {
3083 const hi = try std.fmt.charToDigit(hex[i * 2], 16);
3084 const lo = try std.fmt.charToDigit(hex[i * 2 + 1], 16);
3085 out[i] = (hi << 4) | lo;
3086 }
3087}
3088
3089test "ML-DSA-44 KAT test vector 0" {
3090 // Test vector from NIST ML-DSA KAT (count = 0)
3091 // xi is the seed for key generation (Algorithm 1, line 1)
3092 const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
3093 const pk_hex_start = "bd4e96f9a038ab5e36214fe69c0b1cb835ef9d7c8417e76aecd152f5cddebec8";
3094 const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";
3095
3096 // Parse xi (32-byte seed for key generation)
3097 var xi: [32]u8 = undefined;
3098 try hexToBytes(xi_hex, &xi);
3099
3100 // Generate keys from xi
3101 const result = MLDSA44.newKeyFromSeed(&xi);
3102 const pk = result.pk;
3103 const sk = result.sk;
3104
3105 // Verify public key starts with expected bytes
3106 const pk_bytes = pk.toBytes();
3107
3108 var expected_pk_start: [32]u8 = undefined;
3109 try hexToBytes(pk_hex_start, &expected_pk_start);
3110
3111 // Check first 32 bytes of public key match
3112 try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]);
3113
3114 // Parse message
3115 var msg: [16]u8 = undefined;
3116 try hexToBytes(msg_hex, &msg);
3117
3118 // Sign the message (deterministic mode; per-signature randomness fixed to zero)
3119 const kp = try MLDSA44.KeyPair.fromSecretKey(sk);
3120 const sig = try kp.sign(&msg, null);
3121
3122 // Verify the signature
3123 try sig.verify(&msg, kp.public_key);
3124}
3125
3126test "ML-DSA-65 KAT test vector 0" {
3127 // Test vector from NIST ML-DSA KAT (count = 0)
3128 // xi is the seed for key generation (Algorithm 1, line 1)
3129 const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
3130 const pk_hex_start = "e50d03fff3b3a70961abbb92a390008dec1283f603f50cdbaaa3d00bd659bc76";
3131 const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";
3132
3133 // Parse xi (32-byte seed for key generation)
3134 var xi: [32]u8 = undefined;
3135 try hexToBytes(xi_hex, &xi);
3136
3137 // Generate keys from xi
3138 const result = MLDSA65.newKeyFromSeed(&xi);
3139 const pk = result.pk;
3140 const sk = result.sk;
3141
3142 // Verify public key starts with expected bytes
3143 const pk_bytes = pk.toBytes();
3144
3145 var expected_pk_start: [32]u8 = undefined;
3146 try hexToBytes(pk_hex_start, &expected_pk_start);
3147
3148 // Check first 32 bytes of public key match
3149 try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]);
3150
3151 // Parse message
3152 var msg: [16]u8 = undefined;
3153 try hexToBytes(msg_hex, &msg);
3154
3155 // Sign the message
3156 const kp = try MLDSA65.KeyPair.fromSecretKey(sk);
3157 const sig = try kp.sign(&msg, null);
3158
3159 // Verify the signature
3160 try sig.verify(&msg, kp.public_key);
3161}
3162
3163test "ML-DSA-87 KAT test vector 0" {
3164 // Test vector from NIST ML-DSA KAT (count = 0)
3165 // xi is the seed for key generation (Algorithm 1, line 1)
3166 const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
3167 const pk_hex_start = "bc89b367d4288f47c71a74679d0fcffbe041de41b5da2f5fc66d8e28c5899494";
3168 const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";
3169
3170 // Parse xi (32-byte seed for key generation)
3171 var xi: [32]u8 = undefined;
3172 try hexToBytes(xi_hex, &xi);
3173
3174 // Generate keys from xi
3175 const result = MLDSA87.newKeyFromSeed(&xi);
3176 const pk = result.pk;
3177 const sk = result.sk;
3178
3179 // Verify public key starts with expected bytes
3180 const pk_bytes = pk.toBytes();
3181
3182 var expected_pk_start: [32]u8 = undefined;
3183 try hexToBytes(pk_hex_start, &expected_pk_start);
3184
3185 // Check first 32 bytes of public key match
3186 try testing.expectEqualSlices(u8, &expected_pk_start, pk_bytes[0..32]);
3187
3188 // Parse message
3189 var msg: [16]u8 = undefined;
3190 try hexToBytes(msg_hex, &msg);
3191
3192 // Sign the message
3193 const kp = try MLDSA87.KeyPair.fromSecretKey(sk);
3194 const sig = try kp.sign(&msg, null);
3195
3196 // Verify the signature
3197 try sig.verify(&msg, kp.public_key);
3198}
3199
3200test "KeyPair API - generate and sign" {
3201 // Test the KeyPair API with random key generation
3202 const kp = MLDSA44.KeyPair.generate();
3203 const msg = "Test message for KeyPair API";
3204
3205 // Sign with deterministic mode (no noise)
3206 const sig = try kp.sign(msg, null);
3207
3208 // Verify using Signature.verify API
3209 try sig.verify(msg, kp.public_key);
3210}
3211
3212test "KeyPair API - generateDeterministic" {
3213 // Test deterministic key generation
3214 const seed = [_]u8{42} ** 32;
3215 const kp1 = try MLDSA44.KeyPair.generateDeterministic(seed);
3216 const kp2 = try MLDSA44.KeyPair.generateDeterministic(seed);
3217
3218 // Same seed should produce same keys
3219 const pk1_bytes = kp1.public_key.toBytes();
3220 const pk2_bytes = kp2.public_key.toBytes();
3221 try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes);
3222}
3223
3224test "KeyPair API - fromSecretKey" {
3225 // Generate a key pair
3226 const kp1 = MLDSA44.KeyPair.generate();
3227
3228 // Derive public key from secret key
3229 const kp2 = try MLDSA44.KeyPair.fromSecretKey(kp1.secret_key);
3230
3231 // Public keys should match
3232 const pk1_bytes = kp1.public_key.toBytes();
3233 const pk2_bytes = kp2.public_key.toBytes();
3234 try testing.expectEqualSlices(u8, &pk1_bytes, &pk2_bytes);
3235}
3236
3237test "Signature verification with noise" {
3238 // Test signing with randomness (hedged signatures)
3239 const kp = MLDSA65.KeyPair.generate();
3240 const msg = "Message to be signed with randomness";
3241
3242 // Create some noise
3243 const noise = [_]u8{ 1, 2, 3, 4, 5 } ++ [_]u8{0} ** 27;
3244
3245 // Sign with noise
3246 const sig = try kp.sign(msg, noise);
3247
3248 // Verify should still work
3249 try sig.verify(msg, kp.public_key);
3250}
3251
3252test "Signature verification failure" {
3253 // Test that invalid signatures are rejected
3254 const kp = MLDSA44.KeyPair.generate();
3255 const msg = "Original message";
3256 const sig = try kp.sign(msg, null);
3257
3258 // Verify with wrong message should fail
3259 const wrong_msg = "Different message";
3260 try testing.expectError(error.SignatureVerificationFailed, sig.verify(wrong_msg, kp.public_key));
3261}
3262
3263test "Streaming API - sign and verify" {
3264 const seed = [_]u8{0x55} ** 32;
3265 const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
3266
3267 const msg = "Test message for streaming API";
3268
3269 // Sign using streaming API
3270 var signer = try kp.signer(null);
3271 signer.update(msg);
3272 const sig = signer.finalize();
3273
3274 // Verify using streaming API
3275 var verifier = try sig.verifier(kp.public_key);
3276 verifier.update(msg);
3277 try verifier.verify();
3278}
3279
3280test "Streaming API - chunked message" {
3281 const seed = [_]u8{0x66} ** 32;
3282 const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
3283
3284 // Create a message in chunks
3285 const chunk1 = "Hello, ";
3286 const chunk2 = "streaming ";
3287 const chunk3 = "world!";
3288 const full_msg = chunk1 ++ chunk2 ++ chunk3;
3289
3290 // Sign with chunks
3291 var signer = try kp.signer(null);
3292 signer.update(chunk1);
3293 signer.update(chunk2);
3294 signer.update(chunk3);
3295 const sig_chunked = signer.finalize();
3296
3297 // Sign with full message for comparison
3298 var signer2 = try kp.signer(null);
3299 signer2.update(full_msg);
3300 const sig_full = signer2.finalize();
3301
3302 // Signatures should be identical
3303 try testing.expectEqualSlices(u8, &sig_chunked.toBytes(), &sig_full.toBytes());
3304
3305 // Verify with chunks
3306 const sig = sig_chunked;
3307 var verifier = try sig.verifier(kp.public_key);
3308 verifier.update(chunk1);
3309 verifier.update(chunk2);
3310 verifier.update(chunk3);
3311 try verifier.verify();
3312}
3313
3314test "Streaming API - large message" {
3315 const seed = [_]u8{0x77} ** 32;
3316 const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
3317
3318 // Create a large message (1MB)
3319 const chunk_size = 4096;
3320 const num_chunks = 256;
3321 var chunk: [chunk_size]u8 = undefined;
3322 for (0..chunk_size) |i| {
3323 chunk[i] = @intCast(i % 256);
3324 }
3325
3326 // Sign streaming
3327 var signer = try kp.signer(null);
3328 for (0..num_chunks) |_| {
3329 signer.update(&chunk);
3330 }
3331 const sig = signer.finalize();
3332
3333 // Verify streaming
3334 var verifier = try sig.verifier(kp.public_key);
3335 for (0..num_chunks) |_| {
3336 verifier.update(&chunk);
3337 }
3338 try verifier.verify();
3339}
3340
3341test "Streaming API - all parameter sets" {
3342 const test_msg = "Streaming test for all ML-DSA parameter sets";
3343
3344 // ML-DSA-44
3345 {
3346 const seed = [_]u8{0x44} ** 32;
3347 const kp = try MLDSA44.KeyPair.generateDeterministic(seed);
3348 var signer = try kp.signer(null);
3349 signer.update(test_msg);
3350 const sig = signer.finalize();
3351 var verifier = try sig.verifier(kp.public_key);
3352 verifier.update(test_msg);
3353 try verifier.verify();
3354 }
3355
3356 // ML-DSA-65
3357 {
3358 const seed = [_]u8{0x65} ** 32;
3359 const kp = try MLDSA65.KeyPair.generateDeterministic(seed);
3360 var signer = try kp.signer(null);
3361 signer.update(test_msg);
3362 const sig = signer.finalize();
3363 var verifier = try sig.verifier(kp.public_key);
3364 verifier.update(test_msg);
3365 try verifier.verify();
3366 }
3367
3368 // ML-DSA-87
3369 {
3370 const seed = [_]u8{0x87} ** 32;
3371 const kp = try MLDSA87.KeyPair.generateDeterministic(seed);
3372 var signer = try kp.signer(null);
3373 signer.update(test_msg);
3374 const sig = signer.finalize();
3375 var verifier = try sig.verifier(kp.public_key);
3376 verifier.update(test_msg);
3377 try verifier.verify();
3378 }
3379}
3380
3381 /// Extended Euclidean Algorithm
3382/// Only meant to be used on comptime values; correctness matters, performance doesn't.
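///
/// For example (worked case):
///
/// ```zig
/// const r = extendedEuclidean(i64, 240, 46);
/// // r.gcd == 2, r.x == -9, r.y == 47, and 240 * -9 + 46 * 47 == 2
/// ```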
3383fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
3384 var a = a_;
3385 var b = b_;
3386 var x0: T = 1;
3387 var x1: T = 0;
3388 var y0: T = 0;
3389 var y1: T = 1;
3390
3391 while (b != 0) {
3392 const q = @divTrunc(a, b);
3393 const temp_a = a;
3394 a = b;
3395 b = temp_a - q * b;
3396
3397 const temp_x = x0;
3398 x0 = x1;
3399 x1 = temp_x - q * x1;
3400
3401 const temp_y = y0;
3402 y0 = y1;
3403 y1 = temp_y - q * y1;
3404 }
3405
3406 return .{ .gcd = a, .x = x0, .y = y0 };
3407}
3408
3409/// Modular inversion: computes a^(-1) mod p
3410/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p).
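///
/// For example, `modularInverse(u32, 5, 17)` is 7, since 5 * 7 = 35 ≡ 1 (mod 17).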
3411fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
3412 // Use a signed type for EEA computation
3413 const type_info = @typeInfo(T);
3414 const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
3415 std.meta.Int(.signed, type_info.int.bits)
3416 else
3417 T;
3418
3419 const a_signed = @as(SignedT, @intCast(a));
3420 const p_signed = @as(SignedT, @intCast(p));
3421
3422 const r = extendedEuclidean(SignedT, a_signed, p_signed);
3423 assert(r.gcd == 1);
3424
3425 // Normalize result to [0, p)
3426 var result = r.x;
3427 while (result < 0) {
3428 result += p_signed;
3429 }
3430
3431 return @intCast(result);
3432}
3433
3434/// Modular exponentiation: computes a^s mod p using square-and-multiply algorithm.
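///
/// For example, `modularPow(u32, 3, 5, 7)` is 5, since 3^5 = 243 = 34 * 7 + 5.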
3435fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
3436 const type_info = @typeInfo(T);
3437 const bits = type_info.int.bits;
3438 const WideT = std.meta.Int(.unsigned, bits * 2);
3439
3440 var ret: T = 1;
3441 var base: T = a;
3442 var exp = s;
3443
3444 while (exp > 0) {
3445 if (exp & 1 == 1) {
3446 ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
3447 }
3448 base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
3449 exp >>= 1;
3450 }
3451
3452 return ret;
3453}
3454
3455/// Creates an all-ones or all-zeros mask from a single bit value.
3456/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
3457fn bitMask(comptime T: type, bit: T) T {
3458 const type_info = @typeInfo(T);
3459 if (type_info != .int or type_info.int.signedness != .unsigned) {
3460 @compileError("bitMask requires an unsigned integer type");
3461 }
3462 return -%bit;
3463}
3464
3465/// Creates a mask from the sign bit of a signed integer.
3466/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0.
3467fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
3468 const type_info = @typeInfo(T);
3469 if (type_info != .int) {
3470 @compileError("signMask requires an integer type");
3471 }
3472
3473 const bits = type_info.int.bits;
3474 const SignedT = std.meta.Int(.signed, bits);
3475
3476 // Convert to signed if needed, arithmetic right shift to propagate sign bit
3477 const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
3478 const shifted = x_signed >> (bits - 1);
3479 return @bitCast(shifted);
3480}
3481
3482/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
3483/// This is a generic implementation parameterized by the modulus q, its inverse qInv,
3484/// the Montgomery constant R, and the result bound.
3485///
3486/// For ML-DSA: R = 2^32, returns y < 2q
3487/// For ML-KEM: R = 2^16, returns y in range (-q, q)
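///
/// Equivalently, the output y satisfies y * R ≡ x (mod q). A quick sanity check
/// (illustrative sketch, mirroring the "Montgomery reduction correctness" test in this file):
///
/// ```zig
/// const y = montReduceLe2Q(@as(u64, toMont(12345)));
/// // modQ(y) == 12345
/// ```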
3488fn montgomeryReduce(
3489 comptime InT: type,
3490 comptime OutT: type,
3491 comptime q: comptime_int,
3492 comptime qInv: comptime_int,
3493 comptime r_bits: comptime_int,
3494 x: InT,
3495) OutT {
3496 const mask = (@as(InT, 1) << r_bits) - 1;
3497 const m_full = (x *% qInv) & mask;
3498 const m: OutT = @truncate(m_full);
3499
3500 const yR = x -% @as(InT, m) * @as(InT, q);
3501 const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).int.bits), @bitCast(yR)) >> r_bits;
3502 return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).int.bits), @truncate(y_shifted)));
3503}
3504
3505/// Uniform sampling using SHAKE-128 with rejection sampling.
3506/// Samples polynomial coefficients uniformly from [0, q) using rejection sampling.
3507///
3508/// Parameters:
3509/// - PolyType: The polynomial type to return
3510/// - q: Modulus
3511/// - bits_per_coef: Number of bits per coefficient (12 or 23)
3512/// - n: Number of coefficients
3513/// - seed: Random seed
3514/// - domain_sep: Domain separation bytes (appended to seed)
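///
/// Hypothetical call for the ML-DSA (23-bit) path; the names `rho`, `nonce_lo` and
/// `nonce_hi`, and the two-byte nonce encoding are illustrative assumptions only:
///
/// ```zig
/// const a_ij = sampleUniformRejection(Poly, Q, 23, N, &rho, &[_]u8{ nonce_lo, nonce_hi });
/// ```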
3515fn sampleUniformRejection(
3516 comptime PolyType: type,
3517 comptime q: comptime_int,
3518 comptime bits_per_coef: comptime_int,
3519 comptime n: comptime_int,
3520 seed: []const u8,
3521 domain_sep: []const u8,
3522) PolyType {
3523 var h = sha3.Shake128.init(.{});
3524 h.update(seed);
3525 h.update(domain_sep);
3526
3527 const buf_len = sha3.Shake128.block_length; // 168 bytes
3528 var buf: [buf_len]u8 = undefined;
3529
3530 var ret: PolyType = undefined;
3531 var coef_idx: usize = 0;
3532
3533 if (bits_per_coef == 12) {
3534 // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each)
3535 outer: while (true) {
3536 h.squeeze(&buf);
3537
3538 var j: usize = 0;
3539 while (j < buf_len) : (j += 3) {
3540 const b0 = @as(u16, buf[j]);
3541 const b1 = @as(u16, buf[j + 1]);
3542 const b2 = @as(u16, buf[j + 2]);
3543
3544 const ts: [2]u16 = .{
3545 b0 | ((b1 & 0xf) << 8),
3546 (b1 >> 4) | (b2 << 4),
3547 };
3548
3549 inline for (ts) |t| {
3550 if (t < q) {
3551 ret.cs[coef_idx] = @intCast(t);
3552 coef_idx += 1;
3553 if (coef_idx == n) break :outer;
3554 }
3555 }
3556 }
3557 }
3558 } else if (bits_per_coef == 23) {
3559 // ML-DSA path: 1 coefficient per 3 bytes (23 bits)
3560 while (coef_idx < n) {
3561 h.squeeze(&buf);
3562
3563 var j: usize = 0;
3564 while (j < buf_len and coef_idx < n) : (j += 3) {
3565 const t = (@as(u32, buf[j]) |
3566 (@as(u32, buf[j + 1]) << 8) |
3567 (@as(u32, buf[j + 2]) << 16)) & 0x7fffff;
3568
3569 if (t < q) {
3570 ret.cs[coef_idx] = @intCast(t);
3571 coef_idx += 1;
3572 }
3573 }
3574 }
3575 } else {
3576 @compileError("bits_per_coef must be 12 or 23");
3577 }
3578
3579 return ret;
3580}
3581
3582test "bitMask and signMask helpers" {
3583 try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0));
3584 try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1));
3585 try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0));
3586 try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1));
3587 try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0));
3588 try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1));
3589
3590 try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1));
3591 try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100));
3592 try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0));
3593 try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1));
3594 try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100));
3595
3596 try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
3597 try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
3598}