master
   1//! Module-Lattice-Based Digital Signature Algorithm (ML-DSA) as specified in NIST FIPS 204.
   2//!
   3//! ML-DSA is a post-quantum secure digital signature scheme based on the hardness
   4//! of the Module Learning With Errors (MLWE) and Module Short Integer Solution (MSIS)
   5//! problems over module lattices.
   6//!
   7//! We provide three parameter sets:
   8//!
   9//! - ML-DSA-44: NIST security category 2 (128-bit security)
  10//! - ML-DSA-65: NIST security category 3 (192-bit security)
  11//! - ML-DSA-87: NIST security category 5 (256-bit security)
  12
  13const std = @import("std");
  14const builtin = @import("builtin");
  15const testing = std.testing;
  16const assert = std.debug.assert;
  17const crypto = std.crypto;
  18const errors = std.crypto.errors;
  19const math = std.math;
  20const mem = std.mem;
  21const sha3 = crypto.hash.sha3;
  22
  23const ContextTooLongError = errors.ContextTooLongError;
  24const EncodingError = errors.EncodingError;
  25const SignatureVerificationError = errors.SignatureVerificationError;
  26
  27/// ML-DSA-44 (Module-Lattice-Based Digital Signature Algorithm, 44 parameter set)
  28/// as specified in NIST FIPS 204.
  29///
  30/// This is a post-quantum signature scheme providing NIST security category 2,
  31/// which is roughly equivalent to the security of SHA-256 or AES-128.
  32///
  33/// Key sizes:
  34///
  35/// - Public key: 1312 bytes
  36/// - Secret key: 2560 bytes
  37/// - Signature: 2420 bytes
  38///
  39/// Example usage:
  40///
  41/// ```zig
  42/// const kp = MLDSA44.KeyPair.generate();
  43/// const msg = "Hello, post-quantum world!";
  44/// const sig = try kp.sign(msg, null);
  45/// try sig.verify(msg, kp.public_key);
  46/// ```
pub const MLDSA44 = MLDSAImpl(.{
    .name = "ML-DSA-44",
    .k = 4, // height of matrix A
    .l = 4, // width of matrix A
    .eta = 2, // bound for secret coefficients
    .omega = 80, // maximum number of hint bits
    .tau = 39, // weight of the challenge polynomial
    .gamma1_bits = 17, // bits for gamma1
    .gamma2 = 95232, // (Q-1)/88
    .tr_size = 64, // size of the tr hash
    .ctilde_size = 32, // size of the challenge hash
});
  59
  60/// ML-DSA-65 (Module-Lattice-Based Digital Signature Algorithm, 65 parameter set)
  61/// as specified in NIST FIPS 204.
  62///
  63/// This is a post-quantum signature scheme providing NIST security category 3,
  64/// which is roughly equivalent to the security of SHA-384 or AES-192.
  65///
  66/// Key sizes:
  67///
  68/// - Public key: 1952 bytes
  69/// - Secret key: 4032 bytes
  70/// - Signature: 3309 bytes
  71///
  72/// This parameter set offers higher security than ML-DSA-44 at the cost of
  73/// larger keys and signatures.
pub const MLDSA65 = MLDSAImpl(.{
    .name = "ML-DSA-65",
    .k = 6, // height of matrix A
    .l = 5, // width of matrix A
    .eta = 4, // bound for secret coefficients
    .omega = 55, // maximum number of hint bits
    .tau = 49, // weight of the challenge polynomial
    .gamma1_bits = 19, // bits for gamma1
    .gamma2 = 261888, // (Q-1)/32
    .tr_size = 64, // size of the tr hash
    .ctilde_size = 48, // size of the challenge hash
});
  86
  87/// ML-DSA-87 (Module-Lattice-Based Digital Signature Algorithm, 87 parameter set)
  88/// as specified in NIST FIPS 204.
  89///
  90/// This is a post-quantum signature scheme providing NIST security category 5,
  91/// which is roughly equivalent to the security of SHA-512 or AES-256.
  92///
  93/// Key sizes:
  94///
  95/// - Public key: 2592 bytes
  96/// - Secret key: 4896 bytes
  97/// - Signature: 4627 bytes
  98///
  99/// This parameter set offers the highest security level among the three ML-DSA
 100/// variants, suitable for applications requiring maximum security assurance.
pub const MLDSA87 = MLDSAImpl(.{
    .name = "ML-DSA-87",
    .k = 8, // height of matrix A
    .l = 7, // width of matrix A
    .eta = 2, // bound for secret coefficients
    .omega = 75, // maximum number of hint bits
    .tau = 60, // weight of the challenge polynomial
    .gamma1_bits = 19, // bits for gamma1
    .gamma2 = 261888, // (Q-1)/32
    .tr_size = 64, // size of the tr hash
    .ctilde_size = 64, // size of the challenge hash
});
 113
const N: usize = 256; // Degree of polynomials
const Q: u32 = 8380417; // Modulus: 2^23 - 2^13 + 1
const Q_BITS: u32 = 23; // Bit length of Q (2^22 < Q < 2^23)
const D: u32 = 13; // Dropped bits in power2Round

// Montgomery radix R = 2^32 (stored unreduced, i.e. not taken mod q)
const R: u64 = 1 << 32;

// -(q^-1) mod 2^32, i.e. Q * Q_INV = -1 (mod 2^32).
// The negated inverse is what lets montReduceLe2Q add m*Q instead of subtracting it.
const Q_INV: u32 = 4236238847;

// (256)^(-1) * R^2 mod q, used in inverse NTT
const R_OVER_256: u32 = 41978;

// Primitive 512th root of unity
const ZETA: u32 = 1753;
 130
/// Description of one ML-DSA parameter set.
/// The field names match the literals passed to `MLDSAImpl` by the public
/// parameter sets; presumably this is the type of that argument — confirm
/// against the `MLDSAImpl` signature.
const Params = struct {
    // Human-readable name of the parameter set (e.g. "ML-DSA-44")
    name: []const u8,

    // Matrix dimensions
    k: u8, // Height of matrix A
    l: u8, // Width of matrix A

    // Sampling parameter
    eta: u8, // Bound for secret coefficients

    // Hint parameters
    omega: u16, // Maximum number of hint bits

    // Challenge parameter
    tau: u16, // Weight of challenge polynomial

    // Rounding parameters
    gamma1_bits: u8, // Bits for gamma1
    gamma2: u32, // Parameter for decompose

    // Sizes
    tr_size: usize, // Size of tr hash
    ctilde_size: usize, // Size of challenge hash
};
 155
 156const Poly = struct {
 157    cs: [N]u32,
 158
 159    const zero: Poly = .{ .cs = .{0} ** N };
 160
 161    // Add two polynomials (no normalization)
 162    fn add(a: Poly, b: Poly) Poly {
 163        var ret: Poly = undefined;
 164        for (0..N) |i| {
 165            ret.cs[i] = a.cs[i] + b.cs[i];
 166        }
 167        return ret;
 168    }
 169
 170    // Subtract two polynomials (assumes b coefficients < 2q)
 171    fn sub(a: Poly, b: Poly) Poly {
 172        var ret: Poly = undefined;
 173        for (0..N) |i| {
 174            ret.cs[i] = a.cs[i] +% (@as(u32, 2 * Q) -% b.cs[i]);
 175        }
 176        return ret;
 177    }
 178
 179    // Reduce each coefficient to < 2q
 180    fn reduceLe2Q(p: Poly) Poly {
 181        var ret = p;
 182        for (0..N) |i| {
 183            ret.cs[i] = le2Q(ret.cs[i]);
 184        }
 185        return ret;
 186    }
 187
 188    // Normalize coefficients to [0, q)
 189    fn normalize(p: Poly) Poly {
 190        var ret = p;
 191        for (0..N) |i| {
 192            ret.cs[i] = modQ(ret.cs[i]);
 193        }
 194        return ret;
 195    }
 196
 197    // Normalize assuming coefficients already < 2q
 198    fn normalizeAssumingLe2Q(p: Poly) Poly {
 199        var ret = p;
 200        for (0..N) |i| {
 201            ret.cs[i] = le2qModQ(ret.cs[i]);
 202        }
 203        return ret;
 204    }
 205
 206    // Pointwise multiplication in NTT domain (Montgomery form)
 207    fn mulHat(a: Poly, b: Poly) Poly {
 208        var ret: Poly = undefined;
 209        for (0..N) |i| {
 210            ret.cs[i] = montReduceLe2Q(@as(u64, a.cs[i]) * @as(u64, b.cs[i]));
 211        }
 212        return ret;
 213    }
 214
 215    // Forward NTT
 216    fn ntt(p: Poly) Poly {
 217        var ret = p;
 218        ret.nttInPlace();
 219        return ret;
 220    }
 221
 222    // In-place forward NTT
 223    fn nttInPlace(p: *Poly) void {
 224        var k: usize = 0;
 225        var l: usize = N / 2;
 226
 227        while (l > 0) : (l >>= 1) {
 228            var offset: usize = 0;
 229            while (offset < N - l) : (offset += 2 * l) {
 230                k += 1;
 231                const zeta: u64 = zetas[k];
 232
 233                for (offset..offset + l) |j| {
 234                    const t = montReduceLe2Q(zeta * @as(u64, p.cs[j + l]));
 235                    p.cs[j + l] = p.cs[j] +% (2 * Q -% t);
 236                    p.cs[j] +%= t;
 237                }
 238            }
 239        }
 240    }
 241
 242    // Inverse NTT
 243    fn invNTT(p: Poly) Poly {
 244        var ret = p;
 245        ret.invNTTInPlace();
 246        return ret;
 247    }
 248
 249    // In-place inverse NTT
 250    fn invNTTInPlace(p: *Poly) void {
 251        var k: usize = 0;
 252        var l: usize = 1;
 253
 254        while (l < N) : (l <<= 1) {
 255            var offset: usize = 0;
 256            while (offset < N - l) : (offset += 2 * l) {
 257                const zeta: u64 = inv_zetas[k];
 258                k += 1;
 259
 260                for (offset..offset + l) |j| {
 261                    const t = p.cs[j];
 262                    p.cs[j] = t +% p.cs[j + l];
 263                    p.cs[j + l] = montReduceLe2Q(zeta * @as(u64, t +% 256 * Q -% p.cs[j + l]));
 264                }
 265            }
 266        }
 267
 268        for (0..N) |j| {
 269            p.cs[j] = montReduceLe2Q(@as(u64, R_OVER_256) * @as(u64, p.cs[j]));
 270        }
 271    }
 272
 273    /// Apply Power2Round to all coefficients
 274    /// Returns both t0 and t1 polynomials
 275    fn power2RoundPoly(p: Poly) struct { t0: Poly, t1: Poly } {
 276        var t0 = Poly.zero;
 277        var t1 = Poly.zero;
 278        for (0..N) |i| {
 279            const result = power2Round(p.cs[i]);
 280            t0.cs[i] = result.a0_plus_q;
 281            t1.cs[i] = result.a1;
 282        }
 283        return .{ .t0 = t0, .t1 = t1 };
 284    }
 285
 286    // Check if infinity norm exceeds bound
 287    fn exceeds(p: Poly, bound: u32) bool {
 288        var result: u32 = 0;
 289        for (0..N) |i| {
 290            const x = @as(i32, @intCast((Q - 1) / 2)) - @as(i32, @intCast(p.cs[i]));
 291            const abs_x = x ^ (x >> 31);
 292            const norm = @as(i32, @intCast((Q - 1) / 2)) - abs_x;
 293            const exceeds_bit = @intFromBool(@as(u32, @intCast(norm)) >= bound);
 294            result |= exceeds_bit;
 295        }
 296        return result != 0;
 297    }
 298};
 299
 300fn PolyVec(comptime len: u8) type {
 301    return struct {
 302        ps: [len]Poly,
 303
 304        const Self = @This();
 305        const zero: Self = .{ .ps = .{Poly.zero} ** len };
 306
 307        /// Apply a unary operation to each polynomial in the vector
 308        fn map(v: Self, comptime op: fn (Poly) Poly) Self {
 309            var ret: Self = undefined;
 310            inline for (0..len) |i| {
 311                ret.ps[i] = op(v.ps[i]);
 312            }
 313            return ret;
 314        }
 315
 316        /// Apply a binary operation pairwise to two vectors
 317        fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
 318            var ret: Self = undefined;
 319            inline for (0..len) |i| {
 320                ret.ps[i] = op(a.ps[i], b.ps[i]);
 321            }
 322            return ret;
 323        }
 324
 325        /// Apply a binary operation between a vector and a scalar polynomial
 326        fn mapBinaryPoly(v: Self, scalar: Poly, comptime op: fn (Poly, Poly) Poly) Self {
 327            var ret: Self = undefined;
 328            inline for (0..len) |i| {
 329                ret.ps[i] = op(v.ps[i], scalar);
 330            }
 331            return ret;
 332        }
 333
 334        fn add(a: Self, b: Self) Self {
 335            return mapBinary(a, b, Poly.add);
 336        }
 337
 338        fn sub(a: Self, b: Self) Self {
 339            return mapBinary(a, b, Poly.sub);
 340        }
 341
 342        fn ntt(v: Self) Self {
 343            return map(v, Poly.ntt);
 344        }
 345
 346        fn invNTT(v: Self) Self {
 347            return map(v, Poly.invNTT);
 348        }
 349
 350        fn normalize(v: Self) Self {
 351            return map(v, Poly.normalize);
 352        }
 353
 354        fn reduceLe2Q(v: Self) Self {
 355            return map(v, Poly.reduceLe2Q);
 356        }
 357
 358        fn normalizeAssumingLe2Q(v: Self) Self {
 359            return map(v, Poly.normalizeAssumingLe2Q);
 360        }
 361
 362        // Check if any polynomial in the vector exceeds the bound
 363        fn exceeds(v: Self, bound: u32) bool {
 364            var result = false;
 365            for (0..len) |i| {
 366                result = result or v.ps[i].exceeds(bound);
 367            }
 368            return result;
 369        }
 370
 371        /// Apply Power2Round to each polynomial in the vector
 372        /// Returns both t0 and t1 vectors
 373        fn power2Round(v: Self, t0_out: *Self) Self {
 374            var t1: Self = undefined;
 375            for (0..len) |i| {
 376                const result = v.ps[i].power2RoundPoly();
 377                t0_out.ps[i] = result.t0;
 378                t1.ps[i] = result.t1;
 379            }
 380            return t1;
 381        }
 382
 383        /// Generic packing function for vectors
 384        fn packWith(
 385            v: Self,
 386            buf: []u8,
 387            comptime poly_size: usize,
 388            comptime pack_fn: fn (Poly, []u8) void,
 389        ) void {
 390            inline for (0..len) |i| {
 391                const offset = i * poly_size;
 392                pack_fn(v.ps[i], buf[offset..][0..poly_size]);
 393            }
 394        }
 395
 396        /// Generic unpacking function for vectors
 397        fn unpackWith(
 398            comptime poly_size: usize,
 399            comptime unpack_fn: fn ([]const u8) Poly,
 400            buf: []const u8,
 401        ) Self {
 402            var result: Self = undefined;
 403            inline for (0..len) |i| {
 404                const offset = i * poly_size;
 405                result.ps[i] = unpack_fn(buf[offset..][0..poly_size]);
 406            }
 407            return result;
 408        }
 409
 410        /// Pack T1 vector to bytes
 411        fn packT1(v: Self, buf: []u8) void {
 412            const poly_size = (N * (Q_BITS - D)) / 8;
 413            packWith(v, buf, poly_size, polyPackT1);
 414        }
 415
 416        /// Unpack T1 vector from bytes
 417        fn unpackT1(bytes: []const u8) Self {
 418            const poly_size = (N * (Q_BITS - D)) / 8;
 419            return unpackWith(poly_size, polyUnpackT1, bytes);
 420        }
 421
 422        /// Pack T0 vector to bytes
 423        fn packT0(v: Self, buf: []u8) void {
 424            const poly_size = (N * D) / 8;
 425            packWith(v, buf, poly_size, polyPackT0);
 426        }
 427
 428        /// Unpack T0 vector from bytes
 429        fn unpackT0(buf: []const u8) Self {
 430            const poly_size = (N * D) / 8;
 431            return unpackWith(poly_size, polyUnpackT0, buf);
 432        }
 433
 434        /// Pack vector with coefficients in [-eta, eta]
 435        fn packLeqEta(v: Self, comptime eta: u8, buf: []u8) void {
 436            const poly_size = if (eta == 2) 96 else 128;
 437            const pack_fn = struct {
 438                fn pack(p: Poly, b: []u8) void {
 439                    polyPackLeqEta(p, eta, b);
 440                }
 441            }.pack;
 442            packWith(v, buf, poly_size, pack_fn);
 443        }
 444
 445        /// Unpack vector with coefficients in [-eta, eta]
 446        fn unpackLeqEta(comptime eta: u8, buf: []const u8) Self {
 447            const poly_size = if (eta == 2) 96 else 128;
 448            const unpack_fn = struct {
 449                fn unpack(b: []const u8) Poly {
 450                    return polyUnpackLeqEta(eta, b);
 451                }
 452            }.unpack;
 453            return unpackWith(poly_size, unpack_fn, buf);
 454        }
 455
 456        /// Pack vector of polynomials with coefficients < gamma1
 457        fn packLeGamma1(v: Self, comptime gamma1_bits: u8, buf: []u8) void {
 458            const poly_size = ((gamma1_bits + 1) * N) / 8;
 459            const pack_fn = struct {
 460                fn pack(p: Poly, b: []u8) void {
 461                    polyPackLeGamma1(p, gamma1_bits, b);
 462                }
 463            }.pack;
 464            packWith(v, buf, poly_size, pack_fn);
 465        }
 466
 467        /// Unpack vector of polynomials with coefficients < gamma1
 468        fn unpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Self {
 469            const poly_size = ((gamma1_bits + 1) * N) / 8;
 470            const unpack_fn = struct {
 471                fn unpack(b: []const u8) Poly {
 472                    return polyUnpackLeGamma1(gamma1_bits, b);
 473                }
 474            }.unpack;
 475            return unpackWith(poly_size, unpack_fn, buf);
 476        }
 477
 478        /// Pack high bits w1 for signature verification
 479        fn packW1(v: Self, comptime gamma1_bits: u8, buf: []u8) void {
 480            const poly_size = (N * (Q_BITS - gamma1_bits)) / 8;
 481            const pack_fn = struct {
 482                fn pack(p: Poly, b: []u8) void {
 483                    polyPackW1(p, gamma1_bits, b);
 484                }
 485            }.pack;
 486            packWith(v, buf, poly_size, pack_fn);
 487        }
 488
 489        /// Decompose each polynomial in the vector into high and low bits
 490        fn decomposeVec(v: Self, comptime gamma2: u32, w0_out: *Self) Self {
 491            var w1: Self = undefined;
 492            for (0..len) |i| {
 493                for (0..N) |j| {
 494                    const r = decompose(v.ps[i].cs[j], gamma2);
 495                    w0_out.ps[i].cs[j] = r.a0_plus_q;
 496                    w1.ps[i].cs[j] = r.a1;
 497                }
 498            }
 499            return w1;
 500        }
 501
 502        /// Create hints for vector, returns hint population count
 503        fn makeHintVec(w0mcs2pct0: Self, w1: Self, comptime gamma2: u32) struct { hint: Self, pop: u32 } {
 504            var hint: Self = undefined;
 505            var pop: u32 = 0;
 506            for (0..len) |i| {
 507                const result = polyMakeHint(w0mcs2pct0.ps[i], w1.ps[i], gamma2);
 508                hint.ps[i] = result.hint;
 509                pop += result.count;
 510            }
 511            return .{ .hint = hint, .pop = pop };
 512        }
 513
 514        /// Apply hints to recover high bits
 515        fn useHint(v: Self, hint: Self, comptime gamma2: u32) Self {
 516            var result: Self = undefined;
 517            for (0..len) |i| {
 518                result.ps[i] = polyUseHint(v.ps[i], hint.ps[i], gamma2);
 519            }
 520            return result;
 521        }
 522
 523        /// Multiply vector by 2^D (left shift)
 524        fn mulBy2toD(v: Self) Self {
 525            var result: Self = undefined;
 526            for (0..len) |i| {
 527                for (0..N) |j| {
 528                    result.ps[i].cs[j] = v.ps[i].cs[j] << D;
 529                }
 530            }
 531            return result;
 532        }
 533
 534        /// Sample vector with coefficients uniformly in (-gamma1, gamma1]
 535        /// Wraps expandMask (FIPS 204: ExpandMask)
 536        fn deriveUniformLeGamma1(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Self {
 537            var result: Self = undefined;
 538            for (0..len) |i| {
 539                result.ps[i] = expandMask(gamma1_bits, seed, nonce + @as(u16, @intCast(i)));
 540            }
 541            return result;
 542        }
 543
 544        /// Pack hints into bytes
 545        /// Format: for each polynomial, find positions where hint[i]=1, encode those positions
 546        fn packHint(v: Self, comptime omega: u16, buf: []u8) bool {
 547            var idx: usize = 0;
 548            var count: u32 = 0;
 549
 550            for (0..len) |i| {
 551                for (0..N) |j| {
 552                    if (v.ps[i].cs[j] != 0) {
 553                        count += 1;
 554                    }
 555                }
 556            }
 557
 558            if (count > omega) {
 559                return false;
 560            }
 561
 562            // Hint encoding format per FIPS 204:
 563            // First omega bytes: positions of set bits across all polynomials
 564            // Last len bytes: boundary indices showing where each polynomial's hints end
 565            for (0..len) |i| {
 566                for (0..N) |j| {
 567                    if (v.ps[i].cs[j] != 0) {
 568                        buf[idx] = @intCast(j);
 569                        idx += 1;
 570                    }
 571                }
 572                buf[omega + i] = @intCast(idx);
 573            }
 574
 575            while (idx < omega) : (idx += 1) {
 576                buf[idx] = 0;
 577            }
 578
 579            return true;
 580        }
 581
 582        /// Unpack hints from bytes
 583        fn unpackHint(comptime omega: u16, buf: []const u8) ?Self {
 584            var result: Self = .{ .ps = .{Poly.zero} ** len };
 585            var prev_sop: u8 = 0; // previous switch-over-point
 586
 587            for (0..len) |i| {
 588                const sop = buf[omega + i]; // switch-over-point
 589                if (sop < prev_sop or sop > omega) {
 590                    return null; // ensures switch-over-points are increasing
 591                }
 592
 593                var j = prev_sop;
 594                while (j < sop) : (j += 1) {
 595                    // Validation: indices must be strictly increasing within each polynomial
 596                    if (j > prev_sop and buf[j] <= buf[j - 1]) {
 597                        return null;
 598                    }
 599                    const pos = buf[j];
 600                    if (pos >= N) {
 601                        return null;
 602                    }
 603                    result.ps[i].cs[pos] = 1;
 604                }
 605                prev_sop = sop;
 606            }
 607
 608            var j = prev_sop;
 609            while (j < omega) : (j += 1) {
 610                if (buf[j] != 0) {
 611                    return null;
 612                }
 613            }
 614
 615            return result;
 616        }
 617    };
 618}
 619
 620// Matrix of k x l polynomials
 621
 622fn Mat(comptime k: u8, comptime l: u8) type {
 623    return struct {
 624        rows: [k]PolyVec(l),
 625
 626        const Self = @This();
 627        const VecL = PolyVec(l);
 628        const VecK = PolyVec(k);
 629
 630        /// Expand matrix A from seed rho using SHAKE-128
 631        /// This is the ExpandA function from FIPS 204
 632        fn derive(rho: *const [32]u8) Self {
 633            var m: Self = undefined;
 634            for (0..k) |i| {
 635                if (i + 1 < k) {
 636                    @prefetch(&m.rows[i + 1], .{ .rw = .write, .locality = 2 });
 637                }
 638                for (0..l) |j| {
 639                    // Nonce is i*256 + j
 640                    const nonce: u16 = (@as(u16, @intCast(i)) << 8) | @as(u16, @intCast(j));
 641                    m.rows[i].ps[j] = polyDeriveUniform(rho, nonce);
 642                }
 643            }
 644            return m;
 645        }
 646
 647        /// Multiply matrix by vector in NTT domain and return result in regular domain.
 648        /// Takes a vector in NTT form and returns the product in regular form.
 649        fn mulVec(self: Self, v_hat: VecL) VecK {
 650            var result = VecK.zero;
 651            for (0..k) |i| {
 652                result.ps[i] = dotHat(l, self.rows[i], v_hat);
 653                result.ps[i] = result.ps[i].reduceLe2Q();
 654                result.ps[i] = result.ps[i].invNTT();
 655            }
 656            return result;
 657        }
 658
 659        /// Multiply matrix by vector in NTT domain and return result in NTT domain.
 660        /// Takes a vector in NTT form and returns the product in NTT form.
 661        fn mulVecHat(self: Self, v_hat: VecL) VecK {
 662            var result: VecK = undefined;
 663            for (0..k) |i| {
 664                result.ps[i] = dotHat(l, self.rows[i], v_hat);
 665            }
 666            return result;
 667        }
 668    };
 669}
 670
 671// Dot product in NTT domain
 672fn dotHat(comptime len: u8, a: PolyVec(len), b: PolyVec(len)) Poly {
 673    var ret = Poly.zero;
 674    for (0..len) |i| {
 675        const prod = a.ps[i].mulHat(b.ps[i]);
 676        ret = ret.add(prod);
 677    }
 678    return ret;
 679}
 680
 681// Modular arithmetic operations
 682
 683// Reduce x to [0, 2q) using the fact that 2^23 = 2^13 - 1 (mod q)
 684fn le2Q(x: u32) u32 {
 685    // Write x = x1 * 2^23 + x2 with x2 < 2^23 and x1 < 2^9
 686    // Then x = x2 + x1 * 2^13 - x1 (mod q)
 687    // and x2 + x1 * 2^13 - x1 <= 2^23 + 2^13 < 2q
 688    const x1 = x >> 23;
 689    const x2 = x & 0x7FFFFF; // 2^23 - 1
 690    return x2 +% (x1 << 13) -% x1;
 691}
 692
 693// Reduce x to [0, q)
 694fn modQ(x: u32) u32 {
 695    return le2qModQ(le2Q(x));
 696}
 697
 698// Given x < 2q, reduce to [0, q)
 699fn le2qModQ(x: u32) u32 {
 700    const r = x -% Q;
 701    const mask = signMask(u32, r);
 702    return r +% (mask & Q);
 703}
 704
 705// Montgomery reduction: for x < q*2^32, return y < 2q where y ≡ x*R^(-1) (mod q)
 706// where R = 2^32. This is used for efficient modular multiplication in NTT operations.
 707fn montReduceLe2Q(x: u64) u32 {
 708    const m = (x *% Q_INV) & 0xffffffff;
 709    return @truncate((x +% m * @as(u64, Q)) >> 32);
 710}
 711
 712// Precomputed zetas for NTT (Montgomery form)
 713// zetas[i] = zeta^brv(i) * R mod q
 714const zetas = computeZetas();
 715
 716fn computeZetas() [N]u32 {
 717    @setEvalBranchQuota(100000);
 718    var ret: [N]u32 = undefined;
 719
 720    for (0..N) |i| {
 721        const brv_i = @bitReverse(@as(u8, @intCast(i)));
 722        const power = modularPow(u32, ZETA, brv_i, Q);
 723        ret[i] = toMont(power);
 724    }
 725
 726    return ret;
 727}
 728
 729// Precomputed inverse zetas for inverse NTT
 730const inv_zetas = computeInvZetas();
 731
 732fn computeInvZetas() [N]u32 {
 733    @setEvalBranchQuota(100000);
 734    var ret: [N]u32 = undefined;
 735
 736    const inv_zeta = modularInverse(u32, ZETA, Q);
 737
 738    for (0..N) |i| {
 739        const idx = 255 - i;
 740        const brv_idx = @bitReverse(@as(u8, @intCast(idx)));
 741
 742        // Exponent is -(brv_idx - 256) = 256 - brv_idx
 743        const exp: u32 = @as(u32, 256) - brv_idx;
 744
 745        // Compute inv_zeta^exp
 746        const power = modularPow(u32, inv_zeta, exp, Q);
 747
 748        // Convert to Montgomery form
 749        ret[i] = toMont(power);
 750    }
 751
 752    return ret;
 753}
 754
 755// Convert to Montgomery form: x -> x * R mod q
 756fn toMont(x: u32) u32 {
 757    // R = 2^32, R mod q can be computed as:
 758    // 2^32 mod q = 2^32 mod (2^23 - 2^13 + 1)
 759    // Using the identity 2^23 = 2^13 - 1 (mod q), we can reduce 2^32
 760    // But it's easier to just do: return montReduce(x * R^2 mod q)
 761    // where R^2 mod q is precomputed
 762
 763    // Computing R^2 mod q:
 764    // R = 2^32, so R^2 = 2^64
 765    // We can compute this by noting that R mod q first:
 766    // 2^32 = 2^32 mod q
 767    // But let's use a simpler approach: multiply x by R in the Montgomery domain
 768    // Actually, the simplest is: x * R mod q = montReduceLe2Q(x * R^2 mod q)
 769
 770    // Precompute R^2 mod q at comptime
 771    const r_mod_q = comptime blk: {
 772        // 2^32 mod q - compute by successive squaring
 773        var r: u64 = 1;
 774        for (0..32) |_| {
 775            r = (r * 2) % Q;
 776        }
 777        break :blk @as(u32, @intCast(r));
 778    };
 779
 780    const r2_mod_q = comptime blk: {
 781        const r = @as(u64, r_mod_q);
 782        break :blk @as(u32, @intCast((r * r) % Q));
 783    };
 784
 785    return montReduceLe2Q(@as(u64, x) * @as(u64, r2_mod_q));
 786}
 787
/// Splits 0 ≤ a < Q into a0 and a1 with a = a1*2^D + a0
/// and -2^(D-1) < a0 ≤ 2^(D-1). Returns a0 + Q and a1.
/// The low part is returned offset by Q so that it stays non-negative in a u32.
/// FIPS 204: Power2Round (Algorithm 35 in the final standard)
fn power2Round(a: u32) struct { a0_plus_q: u32, a1: u32 } {
    // We effectively compute a0 = a mod± 2^D
    //                    and a1 = (a - a0) / 2^D
    // without branching; negative values are represented by wrap-around in u32.
    var a0 = a & ((1 << D) - 1); // a mod 2^D

    // a0 is one of 0, 1, ..., 2^(D-1)-1, 2^(D-1), 2^(D-1)+1, ..., 2^D-1
    a0 -%= (1 << (D - 1)) + 1;
    // now a0 is -2^(D-1)-1, -2^(D-1), ..., -2, -1, 0, ..., 2^(D-1)-2

    // Next, add 2^D to those a0 that are negative (seen as i32).
    // The arithmetic shift by 31 yields an all-ones mask exactly for those.
    a0 +%= @as(u32, @bitCast(@as(i32, @bitCast(a0)) >> 31)) & (1 << D);
    // now a0 is 2^(D-1)-1, 2^(D-1), ..., 2^D-2, 2^D-1, 0, ..., 2^(D-1)-2

    a0 -%= (1 << (D - 1)) - 1;
    // now a0 is 0, 1, 2, ..., 2^(D-1)-1, 2^(D-1), -2^(D-1)+1, ..., -1

    const a0_plus_q = Q +% a0;
    // a - a0 is an exact multiple of 2^D, so the shift loses nothing.
    const a1 = (a -% a0) >> D;

    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
}
 812
/// Splits 0 ≤ a < q into a0 and a1 with a = a1*alpha + a0 with -alpha/2 < a0 ≤ alpha/2,
/// except when we would have a1 = (q-1)/alpha in which case a1=0 is taken
/// and -alpha/2 ≤ a0 < 0. Returns a0 + q. Note 0 ≤ a1 < (q-1)/alpha.
/// Recall alpha = 2*gamma2.
///
/// Only the two gamma2 values used by the ML-DSA parameter sets are supported
/// (95232 and 261888); any other value is a compile error.
fn decompose(a: u32, comptime gamma2: u32) struct { a0_plus_q: u32, a1: u32 } {
    const alpha = 2 * gamma2;

    // a1 = ⌈a / 128⌉
    var a1 = (a + 127) >> 7;

    if (alpha == 523776) {
        // For ML-DSA-65 and ML-DSA-87: gamma2 = 261888, alpha = 523776
        // 1025/2^22 is close enough to 1/4092 so that a1 becomes a/alpha rounded down
        a1 = ((a1 * 1025 + (1 << 21)) >> 22);

        // For the corner-case a1 = (q-1)/alpha = 16, we have to set a1=0
        a1 &= 15;
    } else if (alpha == 190464) {
        // For ML-DSA-44: gamma2 = 95232, alpha = 190464
        // 11275/2^24 is close enough to 1/1488 so that a1 becomes a/alpha rounded down
        a1 = ((a1 * 11275) + (1 << 23)) >> 24;

        // For the corner-case a1 = (q-1)/alpha = 44, we have to set a1=0.
        // 43 - a1 is negative (as i32) only when a1 = 44; the resulting
        // all-ones mask makes the XOR clear a1.
        a1 ^= @as(u32, @bitCast(@as(i32, @bitCast(43 -% a1)) >> 31)) & a1;
    } else {
        @compileError("unsupported gamma2/alpha value");
    }

    // May wrap around below zero; the conditional addition of q below fixes that.
    var a0_plus_q = a -% a1 * alpha;

    // In the corner-case, when we set a1=0, we will incorrectly
    // have a0 > (q-1)/2 and we'll need to subtract q. As we
    // return a0 + q, that comes down to adding q if a0 < (q-1)/2.
    a0_plus_q +%= @as(u32, @bitCast(@as(i32, @bitCast(a0_plus_q -% (Q - 1) / 2)) >> 31)) & Q;

    return .{ .a0_plus_q = a0_plus_q, .a1 = a1 };
}
 850
 851/// Creates a hint bit to help recover high bits after a small perturbation.
 852/// Given:
 853/// - z0: the modified low bits (r0 - f mod Q) where f is small
 854/// - r1: the original high bits
 855/// Returns 1 if a hint is needed, 0 otherwise.
 856///
 857/// This implements makeHint from FIPS 204. The hint helps recover r1 from
 858/// r' = r - f without knowing f explicitly.
 859fn makeHint(z0: u32, r1: u32, comptime gamma2: u32) u32 {
 860    // If -alpha/2 < r0 - f <= alpha/2, then r1*alpha + r0 - f is a valid
 861    // decomposition of r' with the restrictions of decompose() and so r'1 = r1.
 862    // So the hint should be 0. This is covered by the first two inequalities.
 863    // There is one other case: if r0 - f = -alpha/2, then r1*alpha + r0 - f is
 864    // also a valid decomposition if r1 = 0. In the other cases a one is carried
 865    // and the hint should be 1.
 866
 867    const cond1 = @intFromBool(z0 <= gamma2);
 868    const cond2 = @intFromBool(z0 > Q - gamma2);
 869    const eq_gamma2 = @intFromBool(z0 == Q - gamma2);
 870    const r1_is_zero = @intFromBool(r1 == 0);
 871    const cond3 = eq_gamma2 & r1_is_zero;
 872
 873    return 1 - (cond1 | cond2 | cond3);
 874}
 875
 876/// Uses a hint to reconstruct high bits from a perturbed value.
 877/// Given:
 878/// - rp: the perturbed value (r' = r - f)
 879/// - hint: the hint bit from makeHint
 880/// Returns the reconstructed high bits r1.
 881///
 882/// This implements useHint from FIPS 204.
 883fn useHint(rp: u32, hint: u32, comptime gamma2: u32) u32 {
 884    const decomp = decompose(rp, gamma2);
 885    const rp0_plus_q = decomp.a0_plus_q;
 886    var rp1 = decomp.a1;
 887
 888    if (hint == 0) {
 889        return rp1;
 890    }
 891
 892    // Depending on gamma2, handle the adjustment differently
 893    if (gamma2 == 261888) {
 894        // ML-DSA-65 and ML-DSA-87: max r1 is 15
 895        if (rp0_plus_q > Q) {
 896            rp1 = (rp1 + 1) & 15;
 897        } else {
 898            rp1 = (rp1 -% 1) & 15;
 899        }
 900    } else if (gamma2 == 95232) {
 901        // ML-DSA-44: max r1 is 43
 902        if (rp0_plus_q > Q) {
 903            if (rp1 == 43) {
 904                rp1 = 0;
 905            } else {
 906                rp1 += 1;
 907            }
 908        } else {
 909            if (rp1 == 0) {
 910                rp1 = 43;
 911            } else {
 912                rp1 -= 1;
 913            }
 914        }
 915    } else {
 916        @compileError("unsupported gamma2 value");
 917    }
 918
 919    return rp1;
 920}
 921
 922/// Creates a hint polynomial for the difference between perturbed and original high bits.
 923/// Returns the number of hint bits set to 1 (the population count).
 924///
 925/// This is used during signature generation to create hints that help verification
 926/// recover the high bits without access to the secret.
 927fn polyMakeHint(p0: Poly, p1: Poly, comptime gamma2: u32) struct { hint: Poly, count: u32 } {
 928    var hint = Poly.zero;
 929    var count: u32 = 0;
 930
 931    for (0..N) |i| {
 932        const h = makeHint(p0.cs[i], p1.cs[i], gamma2);
 933        hint.cs[i] = h;
 934        count += h;
 935    }
 936
 937    return .{ .hint = hint, .count = count };
 938}
 939
 940/// Applies hints to reconstruct high bits from a perturbed polynomial.
 941///
 942/// This is used during signature verification to recover the high bits
 943/// using the hints provided in the signature.
 944fn polyUseHint(q: Poly, hint: Poly, comptime gamma2: u32) Poly {
 945    var result = Poly.zero;
 946
 947    for (0..N) |i| {
 948        result.cs[i] = useHint(q.cs[i], hint.cs[i], gamma2);
 949    }
 950
 951    return result;
 952}
 953
 954/// Pack polynomial with coefficients in [Q-eta, Q+eta] into bytes.
 955/// For eta=2: packs coefficients into 3 bits each (96 bytes total)
 956/// For eta=4: packs coefficients into 4 bits each (128 bytes total)
 957/// Assumes coefficients are not normalized, but in [q-η, q+η].
 958fn polyPackLeqEta(p: Poly, comptime eta: u8, buf: []u8) void {
 959    comptime {
 960        if (eta != 2 and eta != 4) {
 961            @compileError("eta must be 2 or 4");
 962        }
 963    }
 964
 965    if (eta == 2) {
 966        // 3 bits per coefficient: pack 8 coefficients into 3 bytes
 967        var j: usize = 0;
 968        var i: usize = 0;
 969        while (i < buf.len) : (i += 3) {
 970            const c0 = Q + eta - p.cs[j];
 971            const c1 = Q + eta - p.cs[j + 1];
 972            const c2 = Q + eta - p.cs[j + 2];
 973            const c3 = Q + eta - p.cs[j + 3];
 974            const c4 = Q + eta - p.cs[j + 4];
 975            const c5 = Q + eta - p.cs[j + 5];
 976            const c6 = Q + eta - p.cs[j + 6];
 977            const c7 = Q + eta - p.cs[j + 7];
 978
 979            buf[i] = @truncate(c0 | (c1 << 3) | (c2 << 6));
 980            buf[i + 1] = @truncate((c2 >> 2) | (c3 << 1) | (c4 << 4) | (c5 << 7));
 981            buf[i + 2] = @truncate((c5 >> 1) | (c6 << 2) | (c7 << 5));
 982
 983            j += 8;
 984        }
 985    } else { // eta == 4
 986        // 4 bits per coefficient: pack 2 coefficients into 1 byte
 987        var j: usize = 0;
 988        for (0..buf.len) |i| {
 989            const c0 = Q + eta - p.cs[j];
 990            const c1 = Q + eta - p.cs[j + 1];
 991            buf[i] = @truncate(c0 | (c1 << 4));
 992            j += 2;
 993        }
 994    }
 995}
 996
 997/// Unpack polynomial with coefficients in [Q-eta, Q+eta] from bytes.
 998/// Output coefficients will not be normalized, but in [q-η, q+η].
 999fn polyUnpackLeqEta(comptime eta: u8, buf: []const u8) Poly {
1000    comptime {
1001        if (eta != 2 and eta != 4) {
1002            @compileError("eta must be 2 or 4");
1003        }
1004    }
1005
1006    var p = Poly.zero;
1007
1008    if (eta == 2) {
1009        // 3 bits per coefficient: unpack 8 coefficients from 3 bytes
1010        var j: usize = 0;
1011        var i: usize = 0;
1012        while (i < buf.len) : (i += 3) {
1013            p.cs[j] = Q + eta - (buf[i] & 7);
1014            p.cs[j + 1] = Q + eta - ((buf[i] >> 3) & 7);
1015            p.cs[j + 2] = Q + eta - ((buf[i] >> 6) | ((buf[i + 1] << 2) & 7));
1016            p.cs[j + 3] = Q + eta - ((buf[i + 1] >> 1) & 7);
1017            p.cs[j + 4] = Q + eta - ((buf[i + 1] >> 4) & 7);
1018            p.cs[j + 5] = Q + eta - ((buf[i + 1] >> 7) | ((buf[i + 2] << 1) & 7));
1019            p.cs[j + 6] = Q + eta - ((buf[i + 2] >> 2) & 7);
1020            p.cs[j + 7] = Q + eta - ((buf[i + 2] >> 5) & 7);
1021            j += 8;
1022        }
1023    } else { // eta == 4
1024        // 4 bits per coefficient: unpack 2 coefficients from 1 byte
1025        var j: usize = 0;
1026        for (0..buf.len) |i| {
1027            p.cs[j] = Q + eta - (buf[i] & 15);
1028            p.cs[j + 1] = Q + eta - (buf[i] >> 4);
1029            j += 2;
1030        }
1031    }
1032
1033    return p;
1034}
1035
/// Serialize a T1 polynomial (coefficients < 1024) using 10 bits per
/// coefficient, i.e. 4 coefficients per 5 output bytes.
/// The coefficients are expected to be normalized.
/// The amount of work is driven by `buf.len`.
fn polyPackT1(p: Poly, buf: []u8) void {
    var src: usize = 0;
    var dst: usize = 0;
    while (dst < buf.len) : ({
        dst += 5;
        src += 4;
    }) {
        const c = p.cs[src..][0..4];

        // The four 10-bit values are laid out back to back, lowest bits first.
        buf[dst] = @truncate(c[0]);
        buf[dst + 1] = @truncate((c[0] >> 8) | (c[1] << 2));
        buf[dst + 2] = @truncate((c[1] >> 6) | (c[2] << 4));
        buf[dst + 3] = @truncate((c[2] >> 4) | (c[3] << 6));
        buf[dst + 4] = @truncate(c[3] >> 2);
    }
}
1051
/// Deserialize a T1 polynomial: every 5 input bytes yield 4 coefficients of
/// 10 bits each, so all output coefficients are < 1024 (normalized).
/// The amount of work is driven by `buf.len`.
fn polyUnpackT1(buf: []const u8) Poly {
    var p = Poly.zero;
    var src: usize = 0;
    var dst: usize = 0;
    while (src < buf.len) : ({
        src += 5;
        dst += 4;
    }) {
        // Five bytes form a 40-bit little-endian word that holds four
        // consecutive 10-bit fields, lowest bits first.
        const word = mem.readInt(u40, buf[src..][0..5], .little);
        inline for (0..4) |k| {
            const shift: u6 = @intCast(10 * k);
            p.cs[dst + k] = @truncate((word >> shift) & 0x3ff);
        }
    }
    return p;
}
1067
/// Serialize a T0 polynomial (coefficients in (-2^(D-1), 2^(D-1)]) using
/// 13 bits per coefficient, i.e. 8 coefficients per 13 output bytes.
///
/// The coefficients are expected in the non-normalized range
/// (q-2^(D-1), q+2^(D-1)]; each coefficient c is stored as Q + 2^(D-1) - c.
/// The amount of work is driven by `buf.len`.
fn polyPackT0(p: Poly, buf: []u8) void {
    const half_range = 1 << (D - 1);
    var src: usize = 0;
    var dst: usize = 0;
    while (dst < buf.len) : ({
        dst += 13;
        src += 8;
    }) {
        // Map the next eight coefficients to their 13-bit codes.
        var v: [8]u32 = undefined;
        for (&v, p.cs[src..][0..8]) |*slot, coeff| {
            slot.* = Q + half_range - coeff;
        }

        // The eight codes are laid out back to back, lowest bits first.
        buf[dst] = @truncate(v[0]);
        buf[dst + 1] = @truncate((v[0] >> 8) | (v[1] << 5));
        buf[dst + 2] = @truncate(v[1] >> 3);
        buf[dst + 3] = @truncate((v[1] >> 11) | (v[2] << 2));
        buf[dst + 4] = @truncate((v[2] >> 6) | (v[3] << 7));
        buf[dst + 5] = @truncate(v[3] >> 1);
        buf[dst + 6] = @truncate((v[3] >> 9) | (v[4] << 4));
        buf[dst + 7] = @truncate(v[4] >> 4);
        buf[dst + 8] = @truncate((v[4] >> 12) | (v[5] << 1));
        buf[dst + 9] = @truncate((v[5] >> 7) | (v[6] << 6));
        buf[dst + 10] = @truncate(v[6] >> 2);
        buf[dst + 11] = @truncate((v[6] >> 10) | (v[7] << 3));
        buf[dst + 12] = @truncate(v[7] >> 5);
    }
}
1102
/// Deserialize a T0 polynomial written by polyPackT0: every 13 input bytes
/// yield 8 coefficients.
///
/// Each 13-bit field x becomes the coefficient Q + 2^(D-1) - x, so the output
/// is not normalized: it represents values in (-2^(D-1), 2^(D-1)] offset by Q.
/// The amount of work is driven by `buf.len`.
fn polyUnpackT0(buf: []const u8) Poly {
    const half_range = 1 << (D - 1);
    var p = Poly.zero;
    var src: usize = 0;
    var dst: usize = 0;
    while (src < buf.len) : ({
        src += 13;
        dst += 8;
    }) {
        // Thirteen bytes form a 104-bit little-endian word that holds eight
        // consecutive 13-bit fields, lowest bits first.
        const word = mem.readInt(u104, buf[src..][0..13], .little);
        inline for (0..8) |k| {
            const shift: u7 = @intCast(13 * k);
            const field: u32 = @truncate((word >> shift) & 0x1fff);
            p.cs[dst + k] = Q + half_range - field;
        }
    }
    return p;
}
1123
/// Map a coefficient from its centered representation to a non-negative one:
/// a value in [0,γ₁] ∪ (Q-γ₁, Q) is sent to γ₁ - val (mod Q), which lies in [0, 2γ₁).
fn centeredToPositive(val: u32, comptime gamma1: u32) u32 {
    // Wrapping subtraction: a "negative" difference shows up with its top bit set.
    const diff = gamma1 -% val;
    // Add Q back exactly when the difference wrapped; the mask avoids a branch.
    return diff +% (signMask(u32, diff) & Q);
}
1131
/// Serialize a polynomial with coefficients in (-gamma1, gamma1].
///
/// - gamma1_bits=17: 18 bits per coefficient, 4 coefficients per 9 bytes
/// - gamma1_bits=19: 20 bits per coefficient, 2 coefficients per 5 bytes
///
/// The coefficients are expected to be normalized; each one is first mapped
/// from [0,γ₁] ∪ (Q-γ₁, Q) to [0, 2γ₁) with centeredToPositive.
/// The amount of work is driven by `buf.len`.
fn polyPackLeGamma1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void {
    const gamma1: u32 = @as(u32, 1) << gamma1_bits;

    switch (gamma1_bits) {
        17 => {
            var src: usize = 0;
            var dst: usize = 0;
            while (dst < buf.len) : ({
                dst += 9;
                src += 4;
            }) {
                var v: [4]u32 = undefined;
                for (&v, p.cs[src..][0..4]) |*slot, coeff| {
                    slot.* = centeredToPositive(coeff, gamma1);
                }

                // Four 18-bit values laid out back to back, lowest bits first.
                buf[dst] = @truncate(v[0]);
                buf[dst + 1] = @truncate(v[0] >> 8);
                buf[dst + 2] = @truncate((v[0] >> 16) | (v[1] << 2));
                buf[dst + 3] = @truncate(v[1] >> 6);
                buf[dst + 4] = @truncate((v[1] >> 14) | (v[2] << 4));
                buf[dst + 5] = @truncate(v[2] >> 4);
                buf[dst + 6] = @truncate((v[2] >> 12) | (v[3] << 6));
                buf[dst + 7] = @truncate(v[3] >> 2);
                buf[dst + 8] = @truncate(v[3] >> 10);
            }
        },
        19 => {
            var src: usize = 0;
            var dst: usize = 0;
            while (dst < buf.len) : ({
                dst += 5;
                src += 2;
            }) {
                const lo = centeredToPositive(p.cs[src], gamma1);
                const hi = centeredToPositive(p.cs[src + 1], gamma1);

                // Two 20-bit values laid out back to back, lowest bits first.
                buf[dst] = @truncate(lo);
                buf[dst + 1] = @truncate(lo >> 8);
                buf[dst + 2] = @truncate((lo >> 16) | (hi << 4));
                buf[dst + 3] = @truncate(hi >> 4);
                buf[dst + 4] = @truncate(hi >> 12);
            }
        },
        else => @compileError("gamma1_bits must be 17 or 19"),
    }
}
1182
/// Deserialize a polynomial with coefficients in (-gamma1, gamma1].
///
/// - gamma1_bits=17: 9 input bytes yield 4 coefficients (18-bit fields)
/// - gamma1_bits=19: 5 input bytes yield 2 coefficients (20-bit fields)
///
/// Each field is mapped from [0, 2γ₁) back to the centered range with
/// centeredToPositive, so the output coefficients are normalized.
/// The amount of work is driven by `buf.len`.
fn polyUnpackLeGamma1(comptime gamma1_bits: u8, buf: []const u8) Poly {
    const gamma1: u32 = @as(u32, 1) << gamma1_bits;
    var p = Poly.zero;

    switch (gamma1_bits) {
        17 => {
            var src: usize = 0;
            var dst: usize = 0;
            while (src < buf.len) : ({
                src += 9;
                dst += 4;
            }) {
                // Nine bytes form a 72-bit little-endian word that holds four
                // consecutive 18-bit fields, lowest bits first.
                const word = mem.readInt(u72, buf[src..][0..9], .little);
                inline for (0..4) |k| {
                    const shift: u7 = @intCast(18 * k);
                    const field: u32 = @truncate((word >> shift) & 0x3ffff);
                    p.cs[dst + k] = centeredToPositive(field, gamma1);
                }
            }
        },
        19 => {
            var src: usize = 0;
            var dst: usize = 0;
            while (src < buf.len) : ({
                src += 5;
                dst += 2;
            }) {
                // Five bytes form a 40-bit little-endian word that holds two
                // consecutive 20-bit fields, lowest bits first.
                const word = mem.readInt(u40, buf[src..][0..5], .little);
                inline for (0..2) |k| {
                    const shift: u6 = @intCast(20 * k);
                    const field: u32 = @truncate((word >> shift) & 0xfffff);
                    p.cs[dst + k] = centeredToPositive(field, gamma1);
                }
            }
        },
        else => @compileError("gamma1_bits must be 17 or 19"),
    }

    return p;
}
1234
/// Serialize a W1 polynomial for verification.
///
/// - gamma1_bits=17: 6 bits per coefficient, 4 coefficients per 3 bytes
/// - gamma1_bits=19: 4 bits per coefficient, 2 coefficients per byte
///
/// The coefficients are expected to be normalized.
/// The amount of work is driven by `buf.len`.
fn polyPackW1(p: Poly, comptime gamma1_bits: u8, buf: []u8) void {
    switch (gamma1_bits) {
        17 => {
            var src: usize = 0;
            var dst: usize = 0;
            while (dst < buf.len) : ({
                dst += 3;
                src += 4;
            }) {
                const c = p.cs[src..][0..4];

                // Four 6-bit values laid out back to back, lowest bits first.
                buf[dst] = @truncate(c[0] | (c[1] << 6));
                buf[dst + 1] = @truncate((c[1] >> 2) | (c[2] << 4));
                buf[dst + 2] = @truncate((c[2] >> 4) | (c[3] << 2));
            }
        },
        19 => {
            // Two 4-bit values per byte: even coefficient in the low nibble.
            for (buf, 0..) |*byte, k| {
                byte.* = @truncate(p.cs[2 * k] | (p.cs[2 * k + 1] << 4));
            }
        },
        else => @compileError("gamma1_bits must be 17 or 19"),
    }
}
1261
/// Derive a polynomial from a 32-byte seed and a 16-bit nonce.
///
/// The nonce is serialized as two little-endian bytes and handed, together
/// with the seed, to `sampleUniformRejection` (instantiated with `Poly`, `Q`,
/// 23 and `N`); that function's result is returned unchanged.
fn polyDeriveUniform(seed: *const [32]u8, nonce: u16) Poly {
    var nonce_le: [2]u8 = undefined;
    mem.writeInt(u16, &nonce_le, nonce, .little);

    return sampleUniformRejection(Poly, Q, 23, N, seed, &nonce_le);
}
1276
/// Sample a polynomial whose coefficients have norm at most η, driven by
/// SHAKE-256 over seed || nonce (nonce as little-endian u16).
///
/// The result is not normalized: coefficients lie in [q-η, q+η].
/// FIPS 204: ExpandS
fn expandS(comptime eta: u8, seed: *const [64]u8, nonce: u16) Poly {
    comptime {
        if (eta != 2 and eta != 4) {
            @compileError("eta must be 2 or 4");
        }
    }

    // Absorb the 64-byte seed followed by the two nonce bytes.
    const nonce_le = [2]u8{ @truncate(nonce), @truncate(nonce >> 8) };
    var xof = sha3.Shake256.init(.{});
    xof.update(seed);
    xof.update(&nonce_le);

    var p = Poly.zero;
    var filled: usize = 0;

    // One SHAKE-256 block is squeezed at a time.
    var block: [sha3.Shake256.block_length]u8 = undefined;

    while (filled < N) {
        xof.squeeze(&block);

        for (block) |byte| {
            if (filled >= N) break;

            // Every byte provides two 4-bit candidates: low nibble first.
            const candidates = [2]u32{ byte & 15, byte >> 4 };
            for (candidates) |t| {
                if (filled >= N) break;

                // eta=2 rejects only 15; eta=4 accepts 0..2*eta.
                const accepted = if (eta == 2) t <= 14 else t <= 2 * eta;
                if (!accepted) continue;

                // For eta=2 the accepted candidate is reduced mod 5
                // (205/1024 approximates 1/5 on this range).
                const reduced = if (eta == 2) t -% ((205 * t) >> 10) * 5 else t;
                p.cs[filled] = Q + eta - reduced;
                filled += 1;
            }
        }
    }

    return p;
}
1339
/// Sample a polynomial with exactly τ non-zero coefficients, each of them
/// 1 or Q-1 (i.e. ±1), using SHAKE-256 over `seed`.
/// The result is normalized: coefficients are in {0, 1, Q-1}.
/// FIPS 204: SampleInBall
///
/// `tau` must not exceed 64: the sign of every non-zero coefficient is taken
/// from one 64-bit word, so larger values would silently get all-positive
/// signs once the word is exhausted. Such values are rejected at compile time.
fn sampleInBall(comptime tau: u16, seed: []const u8) Poly {
    comptime {
        if (tau > 64) {
            @compileError("tau must be at most 64");
        }
    }

    var p = Poly.zero;

    var buf: [sha3.Shake256.block_length]u8 = undefined; // one SHAKE-256 block

    var h = sha3.Shake256.init(.{});
    h.update(seed);
    h.squeeze(&buf);

    // The first 8 output bytes, read as a little-endian word, supply the sign bits.
    var signs = mem.readInt(u64, buf[0..8], .little);
    var buf_off: usize = 8;

    // Fisher-Yates style shuffle: start from N-tau zeros and insert the tau
    // non-zero coefficients one at a time.
    var i: u16 = N - tau;
    while (i < N) : (i += 1) {
        var b: u16 = undefined;

        // Rejection-sample a position b in [0, i] from single output bytes.
        while (true) {
            if (buf_off >= buf.len) {
                h.squeeze(&buf);
                buf_off = 0;
            }

            b = buf[buf_off];
            buf_off += 1;

            if (b <= i) {
                break;
            }
        }

        // Move whatever is at position b to position i ...
        p.cs[i] = p.cs[b];

        // ... and put the new ±1 at position b. Since 1 ^ (1 | (Q-1)) = Q-1,
        // xor-ing with the masked constant flips 1 to Q-1 when the sign bit is set.
        p.cs[b] = 1;
        const sign_bit: u1 = @truncate(signs);
        const mask = bitMask(u32, sign_bit);
        p.cs[b] ^= mask & (1 | (Q - 1));
        signs >>= 1;
    }

    return p;
}
1394
/// Sample a polynomial with coefficients in (-gamma1, gamma1] from SHAKE-256
/// over seed || nonce (nonce as little-endian u16).
/// Signing uses this to sample the masking vector y.
/// FIPS 204: ExpandMask
fn expandMask(comptime gamma1_bits: u8, seed: *const [64]u8, nonce: u16) Poly {
    const packed_size = ((gamma1_bits + 1) * N) / 8;

    // Absorb the 64-byte seed followed by the two nonce bytes.
    const nonce_le = [2]u8{ @truncate(nonce), @truncate(nonce >> 8) };
    var xof = sha3.Shake256.init(.{});
    xof.update(seed);
    xof.update(&nonce_le);

    // Squeeze exactly one packed polynomial and decode it.
    var stream: [packed_size]u8 = undefined;
    xof.squeeze(&stream);

    return polyUnpackLeGamma1(gamma1_bits, &stream);
}
1415
1416fn MLDSAImpl(comptime p: Params) type {
1417    return struct {
1418        pub const params = p;
1419        pub const name = p.name;
1420        pub const gamma1: u32 = @as(u32, 1) << p.gamma1_bits;
1421        pub const beta: u32 = p.tau * p.eta;
1422        pub const alpha: u32 = 2 * p.gamma2;
1423
1424        const Self = @This();
1425        const PolyVecL = PolyVec(p.l);
1426        const PolyVecK = PolyVec(p.k);
1427        const MatKxL = Mat(p.k, p.l);
1428
1429        /// Length of the seed used for deterministic key generation (32 bytes).
1430        pub const seed_length: usize = 32;
1431
1432        /// Length (in bytes) of optional random bytes, for non-deterministic signatures.
1433        pub const noise_length = 32;
1434
1435        /// Size of an encoded public key in bytes.
1436        pub const public_key_bytes: usize = 32 + polyT1PackedSize() * p.k;
1437
1438        /// Size of an encoded secret key in bytes.
1439        pub const private_key_bytes: usize = 32 + 32 + p.tr_size +
1440            polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k;
1441
1442        /// Size of an encoded signature in bytes.
1443        pub const signature_bytes: usize = p.ctilde_size +
1444            polyLeGamma1PackedSize() * p.l + p.omega + p.k;
1445
1446        // Packed sizes for different polynomial representations
1447        fn polyLeqEtaPackedSize() usize {
1448            // For eta=2: 3 bits per coefficient (values in [0,4])
1449            // For eta=4: 4 bits per coefficient (values in [0,8])
1450            const double_eta_bits = if (p.eta == 2) 3 else 4;
1451            return (N * double_eta_bits) / 8;
1452        }
1453
1454        fn polyLeGamma1PackedSize() usize {
1455            return ((p.gamma1_bits + 1) * N) / 8;
1456        }
1457
1458        fn polyT1PackedSize() usize {
1459            return (N * (Q_BITS - D)) / 8;
1460        }
1461
1462        fn polyT0PackedSize() usize {
1463            return (N * D) / 8;
1464        }
1465
1466        fn polyW1PackedSize() usize {
1467            return (N * (Q_BITS - p.gamma1_bits)) / 8;
1468        }
1469
1470        /// Helper function to compute CRH (Collision Resistant Hash) using SHAKE-256.
1471        /// This consolidates the repeated pattern of init-update-squeeze for hash operations.
1472        fn crh(comptime outsize: usize, inputs: anytype) [outsize]u8 {
1473            var h = sha3.Shake256.init(.{});
1474            inline for (inputs) |input| {
1475                h.update(input);
1476            }
1477            var out: [outsize]u8 = undefined;
1478            h.squeeze(&out);
1479            return out;
1480        }
1481
1482        /// Helper function to compute t = As1 + s2.
1483        /// This is used during key generation and public key reconstruction.
1484        fn computeT(A: MatKxL, s1_hat: PolyVecL, s2: PolyVecK) PolyVecK {
1485            const t = A.mulVec(s1_hat).add(s2);
1486            return t.normalize();
1487        }
1488
1489        /// ML-DSA public key
1490        pub const PublicKey = struct {
1491            /// Size of the encoded public key in bytes
1492            pub const encoded_length: usize = 32 + polyT1PackedSize() * p.k;
1493
1494            rho: [32]u8, // Seed for matrix A
1495            t1: PolyVecK, // High bits of t = As1 + s2
1496
1497            // Cached values
1498            t1_packed: [polyT1PackedSize() * p.k]u8,
1499            A: MatKxL,
1500            tr: [p.tr_size]u8, // CRH(rho || t1)
1501
1502            /// Encode public key to bytes
1503            pub fn toBytes(self: PublicKey) [encoded_length]u8 {
1504                var out: [encoded_length]u8 = undefined;
1505                @memcpy(out[0..32], &self.rho);
1506                @memcpy(out[32..], &self.t1_packed);
1507                return out;
1508            }
1509
1510            /// Decode public key from bytes
1511            pub fn fromBytes(bytes: [encoded_length]u8) !PublicKey {
1512                var pk: PublicKey = undefined;
1513                @memcpy(&pk.rho, bytes[0..32]);
1514                @memcpy(&pk.t1_packed, bytes[32..]);
1515
1516                pk.t1 = PolyVecK.unpackT1(pk.t1_packed[0..]);
1517                pk.A = MatKxL.derive(&pk.rho);
1518                pk.tr = crh(p.tr_size, .{&bytes});
1519
1520                return pk;
1521            }
1522        };
1523
1524        /// ML-DSA secret key
1525        pub const SecretKey = struct {
1526            /// Size of the encoded secret key in bytes
1527            pub const encoded_length: usize = 32 + 32 + p.tr_size +
1528                polyLeqEtaPackedSize() * (p.l + p.k) + polyT0PackedSize() * p.k;
1529
1530            rho: [32]u8, // Seed for matrix A
1531            key: [32]u8, // Seed for signature generation randomness
1532            tr: [p.tr_size]u8, // CRH(rho || t1)
1533            s1: PolyVecL, // Secret vector 1
1534            s2: PolyVecK, // Secret vector 2
1535            t0: PolyVecK, // Low bits of t = As1 + s2
1536
1537            // Cached values (in NTT domain)
1538            A: MatKxL,
1539            s1_hat: PolyVecL,
1540            s2_hat: PolyVecK,
1541            t0_hat: PolyVecK,
1542
1543            /// Encode secret key to bytes
1544            pub fn toBytes(self: SecretKey) [encoded_length]u8 {
1545                var out: [encoded_length]u8 = undefined;
1546                var offset: usize = 0;
1547
1548                @memcpy(out[offset .. offset + 32], &self.rho);
1549                offset += 32;
1550
1551                @memcpy(out[offset .. offset + 32], &self.key);
1552                offset += 32;
1553
1554                @memcpy(out[offset .. offset + p.tr_size], &self.tr);
1555                offset += p.tr_size;
1556
1557                if (p.eta == 2) {
1558                    self.s1.packLeqEta(2, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1559                } else {
1560                    self.s1.packLeqEta(4, out[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1561                }
1562                offset += p.l * polyLeqEtaPackedSize();
1563
1564                if (p.eta == 2) {
1565                    self.s2.packLeqEta(2, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1566                } else {
1567                    self.s2.packLeqEta(4, out[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1568                }
1569                offset += p.k * polyLeqEtaPackedSize();
1570
1571                self.t0.packT0(out[offset..][0 .. p.k * polyT0PackedSize()]);
1572                offset += p.k * polyT0PackedSize();
1573
1574                return out;
1575            }
1576
1577            /// Decode secret key from bytes
1578            pub fn fromBytes(bytes: [encoded_length]u8) !SecretKey {
1579                var sk: SecretKey = undefined;
1580                var offset: usize = 0;
1581
1582                @memcpy(&sk.rho, bytes[offset .. offset + 32]);
1583                offset += 32;
1584
1585                @memcpy(&sk.key, bytes[offset .. offset + 32]);
1586                offset += 32;
1587
1588                @memcpy(&sk.tr, bytes[offset .. offset + p.tr_size]);
1589                offset += p.tr_size;
1590
1591                sk.s1 = if (p.eta == 2)
1592                    PolyVecL.unpackLeqEta(2, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()])
1593                else
1594                    PolyVecL.unpackLeqEta(4, bytes[offset..][0 .. p.l * polyLeqEtaPackedSize()]);
1595                offset += p.l * polyLeqEtaPackedSize();
1596
1597                sk.s2 = if (p.eta == 2)
1598                    PolyVecK.unpackLeqEta(2, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()])
1599                else
1600                    PolyVecK.unpackLeqEta(4, bytes[offset..][0 .. p.k * polyLeqEtaPackedSize()]);
1601                offset += p.k * polyLeqEtaPackedSize();
1602
1603                sk.t0 = PolyVecK.unpackT0(bytes[offset..][0 .. p.k * polyT0PackedSize()]);
1604                offset += p.k * polyT0PackedSize();
1605
1606                // Compute cached NTT values for efficient signing
1607                sk.A = MatKxL.derive(&sk.rho);
1608                sk.s1_hat = sk.s1.ntt();
1609                sk.s2_hat = sk.s2.ntt();
1610                sk.t0_hat = sk.t0.ntt();
1611
1612                return sk;
1613            }
1614
1615            /// Compute the public key from this private key
1616            pub fn public(self: *const SecretKey) PublicKey {
1617                var pk: PublicKey = undefined;
1618                pk.rho = self.rho;
1619                pk.A = self.A;
1620                pk.tr = self.tr;
1621
1622                // Reconstruct t = As1 + s2, then extract high bits t1
1623                // Using power2Round: t = t1 * 2^D + t0
1624                const t = computeT(self.A, self.s1_hat, self.s2);
1625
1626                var t0_unused: PolyVecK = undefined;
1627                pk.t1 = t.power2Round(&t0_unused);
1628                pk.t1.packT1(&pk.t1_packed);
1629
1630                return pk;
1631            }
1632
1633            /// Create a Signer for incrementally signing a message.
1634            /// The noise parameter can be null for deterministic signatures,
1635            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1636            pub fn signer(self: *const SecretKey, noise: ?[noise_length]u8) !Signer {
1637                return self.signerWithContext(noise, "");
1638            }
1639
1640            /// Create a Signer for incrementally signing a message with context.
1641            /// The noise parameter can be null for deterministic signatures,
1642            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1643            /// The context parameter is an optional context string (max 255 bytes).
1644            pub fn signerWithContext(self: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
1645                return Signer.init(self, noise, context);
1646            }
1647        };
1648
1649        /// Generate a new key pair from a seed (deterministic)
1650        pub fn newKeyFromSeed(seed: *const [seed_length]u8) struct { pk: PublicKey, sk: SecretKey } {
1651            var sk: SecretKey = undefined;
1652            var pk: PublicKey = undefined;
1653
1654            // NIST mode: expand seed || k || l using SHAKE-256 to get 128-byte expanded seed
1655            const e_seed = crh(128, .{ seed, &[_]u8{ p.k, p.l } });
1656
1657            @memcpy(&pk.rho, e_seed[0..32]);
1658            const s_seed = e_seed[32..96];
1659            @memcpy(&sk.key, e_seed[96..128]);
1660            @memcpy(&sk.rho, &pk.rho);
1661
1662            sk.A = MatKxL.derive(&pk.rho);
1663            pk.A = sk.A;
1664
1665            const s_seed_array: *const [64]u8 = s_seed[0..64];
1666            for (0..p.l) |i| {
1667                sk.s1.ps[i] = expandS(p.eta, s_seed_array, @intCast(i));
1668            }
1669
1670            for (0..p.k) |i| {
1671                sk.s2.ps[i] = expandS(p.eta, s_seed_array, @intCast(p.l + i));
1672            }
1673
1674            sk.s1_hat = sk.s1.ntt();
1675            sk.s2_hat = sk.s2.ntt();
1676
1677            const t = computeT(sk.A, sk.s1_hat, sk.s2);
1678
1679            pk.t1 = t.power2Round(&sk.t0);
1680            sk.t0_hat = sk.t0.ntt();
1681            pk.t1.packT1(&pk.t1_packed);
1682
1683            // tr = H(pk) = H(rho || t1)
1684            const pk_bytes = pk.toBytes();
1685            const tr = crh(p.tr_size, .{&pk_bytes});
1686            sk.tr = tr;
1687            pk.tr = tr;
1688
1689            return .{ .pk = pk, .sk = sk };
1690        }
1691
1692        /// ML-DSA signature
1693        pub const Signature = struct {
1694            /// Size of the encoded signature in bytes
1695            pub const encoded_length: usize = p.ctilde_size +
1696                polyLeGamma1PackedSize() * p.l + p.omega + p.k;
1697
1698            c_tilde: [p.ctilde_size]u8, // Challenge hash
1699            z: PolyVecL, // Response vector
1700            hint: PolyVecK, // Hint vector
1701
1702            /// Encode signature to bytes
1703            pub fn toBytes(self: Signature) [encoded_length]u8 {
1704                var out: [encoded_length]u8 = undefined;
1705                var offset: usize = 0;
1706
1707                @memcpy(out[offset .. offset + p.ctilde_size], &self.c_tilde);
1708                offset += p.ctilde_size;
1709
1710                self.z.packLeGamma1(p.gamma1_bits, out[offset .. offset + polyLeGamma1PackedSize() * p.l]);
1711                offset += polyLeGamma1PackedSize() * p.l;
1712
1713                _ = self.hint.packHint(p.omega, out[offset..]);
1714
1715                return out;
1716            }
1717
1718            /// Decode signature from bytes
1719            pub fn fromBytes(bytes: [encoded_length]u8) EncodingError!Signature {
1720                var sig: Signature = undefined;
1721                var offset: usize = 0;
1722
1723                @memcpy(&sig.c_tilde, bytes[offset .. offset + p.ctilde_size]);
1724                offset += p.ctilde_size;
1725
1726                sig.z = PolyVecL.unpackLeGamma1(p.gamma1_bits, bytes[offset .. offset + polyLeGamma1PackedSize() * p.l]);
1727                offset += polyLeGamma1PackedSize() * p.l;
1728
1729                // Validate ||z||_inf < gamma1 - beta per FIPS 204
1730                if (sig.z.exceeds(gamma1 - beta)) {
1731                    return error.InvalidEncoding;
1732                }
1733
1734                sig.hint = PolyVecK.unpackHint(p.omega, bytes[offset..]) orelse return error.InvalidEncoding;
1735
1736                return sig;
1737            }
1738
1739            pub const VerifyError = Verifier.InitError || Verifier.VerifyError;
1740
1741            /// Verify this signature against a message and public key.
1742            /// Returns an error if the signature is invalid.
1743            pub fn verify(
1744                sig: Signature,
1745                msg: []const u8,
1746                public_key: PublicKey,
1747            ) VerifyError!void {
1748                return sig.verifyWithContext(msg, public_key, "");
1749            }
1750
1751            /// Verify this signature against a message and public key with context.
1752            /// Returns an error if the signature is invalid.
1753            /// The context parameter is an optional context string (max 255 bytes).
1754            pub fn verifyWithContext(
1755                sig: Signature,
1756                msg: []const u8,
1757                public_key: PublicKey,
1758                context: []const u8,
1759            ) VerifyError!void {
1760                if (context.len > 255) {
1761                    return error.SignatureVerificationFailed;
1762                }
1763
1764                var h = sha3.Shake256.init(.{});
1765                h.update(&public_key.tr);
1766                h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1767                h.update(&[_]u8{@intCast(context.len)});
1768                if (context.len > 0) {
1769                    h.update(context);
1770                }
1771                h.update(msg);
1772                var mu: [64]u8 = undefined;
1773                h.squeeze(&mu);
1774
1775                const z_hat = sig.z.ntt();
1776                const Az = public_key.A.mulVecHat(z_hat);
1777
1778                // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing)
1779                var Az2dct1 = public_key.t1.mulBy2toD();
1780                Az2dct1 = Az2dct1.ntt();
1781                const c_poly = sampleInBall(p.tau, &sig.c_tilde);
1782                const c_hat = c_poly.ntt();
1783                for (0..p.k) |i| {
1784                    Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat);
1785                }
1786                Az2dct1 = Az.sub(Az2dct1);
1787                Az2dct1 = Az2dct1.reduceLe2Q();
1788                Az2dct1 = Az2dct1.invNTT();
1789                Az2dct1 = Az2dct1.normalizeAssumingLe2Q();
1790
1791                // Apply hints to recover high bits w1'
1792                var w1_prime = Az2dct1.useHint(sig.hint, p.gamma2);
1793                var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
1794                w1_prime.packW1(p.gamma1_bits, &w1_packed);
1795
1796                const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed });
1797
1798                if (!mem.eql(u8, &c_prime, &sig.c_tilde)) {
1799                    return error.SignatureVerificationFailed;
1800                }
1801            }
1802
1803            /// Create a Verifier for incrementally verifying a signature.
1804            pub fn verifier(self: Signature, public_key: PublicKey) !Verifier {
1805                return self.verifierWithContext(public_key, "");
1806            }
1807
1808            /// Create a Verifier for incrementally verifying a signature with context.
1809            /// The context parameter is an optional context string (max 255 bytes).
1810            pub fn verifierWithContext(self: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier {
1811                return Verifier.init(self, public_key, context);
1812            }
1813        };
1814
1815        /// A Signer is used to incrementally compute a signature over a streamed message.
1816        /// It can be obtained from a `SecretKey` or `KeyPair`, using the `signer()` function.
1817        pub const Signer = struct {
1818            h: sha3.Shake256, // For computing μ = CRH(tr || msg)
1819            secret_key: *const SecretKey,
1820            rnd: [32]u8,
1821
1822            /// Initialize a new Signer.
1823            /// The noise parameter can be null for deterministic signatures,
1824            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
1825            /// The context parameter is an optional context string (max 255 bytes).
1826            pub fn init(secret_key: *const SecretKey, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
1827                if (context.len > 255) {
1828                    return error.ContextTooLong;
1829                }
1830
1831                var h = sha3.Shake256.init(.{});
1832                h.update(&secret_key.tr);
1833                h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1834                h.update(&[_]u8{@intCast(context.len)});
1835                if (context.len > 0) {
1836                    h.update(context);
1837                }
1838
1839                return Signer{
1840                    .h = h,
1841                    .secret_key = secret_key,
1842                    .rnd = noise orelse .{0} ** 32,
1843                };
1844            }
1845
1846            /// Add new data to the message being signed.
1847            pub fn update(self: *Signer, data: []const u8) void {
1848                self.h.update(data);
1849            }
1850
1851            /// Compute a signature over the entire message.
1852            pub fn finalize(self: *Signer) Signature {
1853                var mu: [64]u8 = undefined;
1854                self.h.squeeze(&mu);
1855
1856                const rho_prime = crh(64, .{ &self.secret_key.key, &self.rnd, &mu });
1857
1858                var sig: Signature = undefined;
1859                var y_nonce: u16 = 0;
1860
1861                // Rejection sampling loop (FIPS 204 Algorithm 2, steps 5-16)
1862                var attempt: u32 = 0;
1863                while (true) {
1864                    attempt += 1;
1865                    if (attempt >= 576) { // (6/7)⁵⁷⁶ < 2⁻¹²⁸
1866                        @branchHint(.unlikely);
1867                        unreachable;
1868                    }
1869
1870                    const y = PolyVecL.deriveUniformLeGamma1(p.gamma1_bits, &rho_prime, y_nonce);
1871                    y_nonce += @intCast(p.l);
1872
1873                    const y_hat = y.ntt();
1874                    var w = self.secret_key.A.mulVec(y_hat);
1875
1876                    w = w.normalize();
1877                    var w0: PolyVecK = undefined;
1878                    const w1 = w.decomposeVec(p.gamma2, &w0);
1879                    var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
1880                    w1.packW1(p.gamma1_bits, &w1_packed);
1881
1882                    sig.c_tilde = crh(p.ctilde_size, .{ &mu, &w1_packed });
1883
1884                    const c_poly = sampleInBall(p.tau, &sig.c_tilde);
1885                    const c_hat = c_poly.ntt();
1886
1887                    // Rejection check: ensure masking is effective
1888                    var w0mcs2: PolyVecK = undefined;
1889                    for (0..p.k) |i| {
1890                        w0mcs2.ps[i] = c_hat.mulHat(self.secret_key.s2_hat.ps[i]);
1891                        w0mcs2.ps[i] = w0mcs2.ps[i].invNTT();
1892                    }
1893                    w0mcs2 = w0.sub(w0mcs2);
1894                    w0mcs2 = w0mcs2.normalize();
1895
1896                    if (w0mcs2.exceeds(p.gamma2 - beta)) {
1897                        continue;
1898                    }
1899
1900                    // Compute response z = y + c·s1
1901                    for (0..p.l) |i| {
1902                        sig.z.ps[i] = c_hat.mulHat(self.secret_key.s1_hat.ps[i]);
1903                        sig.z.ps[i] = sig.z.ps[i].invNTT();
1904                    }
1905                    sig.z = sig.z.add(y);
1906                    sig.z = sig.z.normalize();
1907
1908                    if (sig.z.exceeds(gamma1 - beta)) {
1909                        continue;
1910                    }
1911
1912                    var ct0: PolyVecK = undefined;
1913                    for (0..p.k) |i| {
1914                        ct0.ps[i] = c_hat.mulHat(self.secret_key.t0_hat.ps[i]);
1915                        ct0.ps[i] = ct0.ps[i].invNTT();
1916                    }
1917                    ct0 = ct0.reduceLe2Q();
1918                    ct0 = ct0.normalize();
1919
1920                    if (ct0.exceeds(p.gamma2)) {
1921                        continue;
1922                    }
1923
1924                    // Generate hints for verification
1925                    var w0mcs2pct0 = w0mcs2.add(ct0);
1926                    w0mcs2pct0 = w0mcs2pct0.reduceLe2Q();
1927                    w0mcs2pct0 = w0mcs2pct0.normalizeAssumingLe2Q();
1928                    const hint_result = PolyVecK.makeHintVec(w0mcs2pct0, w1, p.gamma2);
1929                    if (hint_result.pop > p.omega) {
1930                        continue;
1931                    }
1932                    sig.hint = hint_result.hint;
1933
1934                    return sig;
1935                }
1936            }
1937        };
1938
1939        /// A Verifier is used to incrementally verify a signature over a streamed message.
1940        /// It can be obtained from a `Signature`, using the `verifier()` function.
1941        pub const Verifier = struct {
1942            h: sha3.Shake256, // For computing μ = CRH(tr || msg)
1943            signature: Signature,
1944            public_key: PublicKey,
1945
1946            pub const InitError = EncodingError;
1947            pub const VerifyError = SignatureVerificationError;
1948
1949            /// Initialize a new Verifier.
1950            /// The context parameter is an optional context string (max 255 bytes).
1951            pub fn init(signature: Signature, public_key: PublicKey, context: []const u8) ContextTooLongError!Verifier {
1952                if (context.len > 255) {
1953                    return error.ContextTooLong;
1954                }
1955
1956                var h = sha3.Shake256.init(.{});
1957                h.update(&public_key.tr);
1958                h.update(&[_]u8{0}); // Domain separator: 0 for pure ML-DSA
1959                h.update(&[_]u8{@intCast(context.len)}); // Context length
1960                if (context.len > 0) {
1961                    h.update(context);
1962                }
1963
1964                return Verifier{
1965                    .h = h,
1966                    .signature = signature,
1967                    .public_key = public_key,
1968                };
1969            }
1970
1971            /// Add new content to the message to be verified.
1972            pub fn update(self: *Verifier, data: []const u8) void {
1973                self.h.update(data);
1974            }
1975
1976            /// Verify that the signature is valid for the entire message.
1977            pub fn verify(self: *Verifier) SignatureVerificationError!void {
1978                var mu: [64]u8 = undefined;
1979                self.h.squeeze(&mu);
1980
1981                const z_hat = self.signature.z.ntt();
1982                const Az = self.public_key.A.mulVecHat(z_hat);
1983
1984                // Compute w' ≈ Az - 2^d·c·t1 (approximate w used in signing)
1985                var Az2dct1 = self.public_key.t1.mulBy2toD();
1986                Az2dct1 = Az2dct1.ntt();
1987                const c_poly = sampleInBall(p.tau, &self.signature.c_tilde);
1988                const c_hat = c_poly.ntt();
1989                for (0..p.k) |i| {
1990                    Az2dct1.ps[i] = Az2dct1.ps[i].mulHat(c_hat);
1991                }
1992                Az2dct1 = Az.sub(Az2dct1);
1993                Az2dct1 = Az2dct1.reduceLe2Q();
1994                Az2dct1 = Az2dct1.invNTT();
1995                Az2dct1 = Az2dct1.normalizeAssumingLe2Q();
1996
1997                // Apply hints to recover high bits w1'
1998                var w1_prime = Az2dct1.useHint(self.signature.hint, p.gamma2);
1999                var w1_packed: [polyW1PackedSize() * p.k]u8 = undefined;
2000                w1_prime.packW1(p.gamma1_bits, &w1_packed);
2001
2002                const c_prime = crh(p.ctilde_size, .{ &mu, &w1_packed });
2003
2004                if (!mem.eql(u8, &c_prime, &self.signature.c_tilde)) {
2005                    return error.SignatureVerificationFailed;
2006                }
2007            }
2008        };
2009
2010        /// A key pair consisting of a secret key and its corresponding public key.
2011        pub const KeyPair = struct {
2012            /// Length (in bytes) of a seed required to create a key pair.
2013            pub const seed_length = Self.seed_length;
2014
2015            /// The public key component.
2016            public_key: PublicKey,
2017
2018            /// The secret key component.
2019            secret_key: SecretKey,
2020
2021            /// Generate a new random key pair.
2022            /// This uses the system's cryptographically secure random number generator.
2023            ///
2024            /// `crypto.random.bytes` must be supported by the target.
2025            pub fn generate() KeyPair {
2026                var seed: [Self.seed_length]u8 = undefined;
2027                crypto.random.bytes(&seed);
2028                return generateDeterministic(seed) catch unreachable;
2029            }
2030
2031            /// Generate a key pair deterministically from a seed.
2032            /// Use for testing or when reproducibility is required.
2033            /// The seed should be generated using a cryptographically secure random source.
2034            pub fn generateDeterministic(seed: [32]u8) !KeyPair {
2035                const keys = newKeyFromSeed(&seed);
2036                return .{
2037                    .public_key = keys.pk,
2038                    .secret_key = keys.sk,
2039                };
2040            }
2041
2042            /// Derive the public key from an existing secret key.
2043            /// This recomputes the public key components from the secret key.
2044            pub fn fromSecretKey(sk: SecretKey) !KeyPair {
2045                var pk: PublicKey = undefined;
2046                pk.rho = sk.rho;
2047                pk.tr = sk.tr;
2048                pk.A = sk.A;
2049
2050                const t = computeT(sk.A, sk.s1_hat, sk.s2);
2051
2052                var t0: PolyVecK = undefined;
2053                pk.t1 = t.power2Round(&t0);
2054                pk.t1.packT1(&pk.t1_packed);
2055
2056                return .{
2057                    .public_key = pk,
2058                    .secret_key = sk,
2059                };
2060            }
2061
2062            /// Create a Signer for incrementally signing a message.
2063            /// The noise parameter can be null for deterministic signatures,
2064            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2065            pub fn signer(self: *const KeyPair, noise: ?[noise_length]u8) !Signer {
2066                return self.secret_key.signer(noise);
2067            }
2068
2069            /// Create a Signer for incrementally signing a message with context.
2070            /// The noise parameter can be null for deterministic signatures,
2071            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2072            /// The context parameter is an optional context string (max 255 bytes).
2073            pub fn signerWithContext(self: *const KeyPair, noise: ?[noise_length]u8, context: []const u8) ContextTooLongError!Signer {
2074                return self.secret_key.signerWithContext(noise, context);
2075            }
2076
2077            /// Sign a message using this key pair.
2078            /// The noise parameter can be null for deterministic signatures,
2079            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2080            pub fn sign(
2081                kp: KeyPair,
2082                msg: []const u8,
2083                noise: ?[noise_length]u8,
2084            ) !Signature {
2085                return kp.signWithContext(msg, noise, "");
2086            }
2087
2088            /// Sign a message using this key pair with context.
2089            /// The noise parameter can be null for deterministic signatures,
2090            /// or provide randomness for hedged signatures (recommended for fault attack resistance).
2091            /// The context parameter is an optional context string (max 255 bytes).
2092            pub fn signWithContext(
2093                kp: KeyPair,
2094                msg: []const u8,
2095                noise: ?[noise_length]u8,
2096                context: []const u8,
2097            ) ContextTooLongError!Signature {
2098                var st = try kp.signerWithContext(noise, context);
2099                st.update(msg);
2100                return st.finalize();
2101            }
2102        };
2103    };
2104}
2105
test "modular arithmetic" {
    // Montgomery reduction of a sample value must stay below 2Q.
    const reduced = montReduceLe2Q(@as(u64, 12345678));
    try testing.expect(reduced < 2 * Q);

    // modQ maps Q to 0 and Q + 1 to 1.
    try testing.expectEqual(@as(u32, 0), modQ(Q));
    try testing.expectEqual(@as(u32, 1), modQ(Q + 1));
}
2116
test "polynomial operations" {
    // lhs = 1 + 2x
    var lhs = Poly.zero;
    lhs.cs[0] = 1;
    lhs.cs[1] = 2;

    // rhs = 3 + 4x
    var rhs = Poly.zero;
    rhs.cs[0] = 3;
    rhs.cs[1] = 4;

    // Addition is coefficient-wise: 4 + 6x.
    const sum = lhs.add(rhs);
    try testing.expectEqual(@as(u32, 4), sum.cs[0]);
    try testing.expectEqual(@as(u32, 6), sum.cs[1]);
}
2130
test "NTT and inverse NTT" {
    // Input polynomial in REGULAR FORM (not Montgomery): coefficient i is i mod Q.
    var input = Poly.zero;
    for (&input.cs, 0..) |*coeff, i| {
        coeff.* = @intCast(i % Q);
    }

    // Forward transform, reduced before inverting (as the Go test does).
    var transformed = input.ntt();
    transformed = transformed.reduceLe2Q();

    // According to the Dilithium spec, NTT followed by invNTT multiplies by R,
    // so the round trip yields the input in Montgomery form.
    const restored = transformed.invNTT();
    const normalized = restored.reduceLe2Q().normalize();

    // Every coefficient must equal toMont(original), reduced mod Q.
    for (normalized.cs, 0..) |actual, i| {
        const original: u32 = @intCast(i % Q);
        try testing.expectEqual(modQ(toMont(original)), actual);
    }
}
2160
test "parameter set instantiation" {
    // Referencing `name` forces each of the three parameter sets to be instantiated.
    try testing.expectEqualStrings("ML-DSA-44", MLDSA44.name);
    try testing.expectEqualStrings("ML-DSA-65", MLDSA65.name);
    try testing.expectEqualStrings("ML-DSA-87", MLDSA87.name);
}
2171
test "compare zetas with Go implementation" {
    // Reference values: the first 16 zetas of the Go implementation (Montgomery form).
    const go_zetas = [16]u32{
        4193792, 25847,   5771523, 7861508,
        237124,  7602457, 7504169, 466468,
        1826347, 2353451, 8021166, 6288512,
        3119733, 5495562, 3111497, 2680103,
    };

    // Our computed table must agree entry by entry.
    for (go_zetas, 0..) |expected, i| {
        try testing.expectEqual(expected, zetas[i]);
    }
}
2185
test "NTT with simple polynomial" {
    // The constant polynomial 1, in regular (non-Montgomery) form.
    var unit = Poly.zero;
    unit.cs[0] = 1;

    // Forward transform, reduced before inverting (as the Go test does).
    var transformed = unit.ntt();
    transformed = transformed.reduceLe2Q();

    const restored = transformed.invNTT();
    const normalized = restored.reduceLe2Q().normalize();

    // The round trip scales by R, so coefficient 0 becomes 1 * R = toMont(1) ...
    try testing.expectEqual(modQ(toMont(1)), normalized.cs[0]);

    // ... and every other coefficient stays 0 * R = 0.
    for (normalized.cs[1..]) |coeff| {
        try testing.expectEqual(@as(u32, 0), coeff);
    }
}
2210
test "Montgomery reduction correctness" {
    // Multiplying two Montgomery-form values and reducing twice must agree with
    // plain multiplication mod q:
    // montReduceLe2Q(a * b * R) = a * b mod q (where a, b are in Montgomery form)
    const a: u32 = 12345;
    const b: u32 = 67890;

    // Reference result: direct multiplication mod q.
    const plain_product: u64 = @as(u64, a) * @as(u64, b);
    const expected = modQ(@as(u32, @intCast(plain_product % Q)));

    // Montgomery path: convert both factors, multiply, then convert back.
    const a_mont = toMont(a);
    const b_mont = toMont(b);
    const product_mont = montReduceLe2Q(@as(u64, a_mont) * @as(u64, b_mont));
    const product = montReduceLe2Q(@as(u64, product_mont));

    try testing.expectEqual(expected, modQ(product));
}
2233
2234// Removed debug test - was causing noise in output
2235
test "compare inv_zetas with Go implementation" {
    // Reference values: the first 16 inv_zetas of the Go implementation.
    const go_inv_zetas = [16]u32{
        6403635, 846154,  6979993, 4442679,
        1362209, 48306,   4460757, 554416,
        3545687, 6767575, 976891,  8196974,
        2286327, 420899,  2235985, 2939036,
    };

    // Our computed table must agree entry by entry. The print names the offending
    // index, which expectEqual alone does not report.
    for (go_inv_zetas, 0..) |expected, i| {
        const actual = inv_zetas[i];
        if (actual != expected) {
            std.debug.print("Mismatch at inv_zetas[{d}]: got {d}, expected {d}\n", .{ i, actual, expected });
        }
        try testing.expectEqual(expected, actual);
    }
}
2252
test "power2Round correctness" {
    // power2Round must split a into (a1, a0) such that a = a1*2^D + a0
    // with -2^(D-1) < a0 <= 2^(D-1). Checked on a handful of sample values.
    const bound: i32 = 1 << (D - 1);
    const samples = [_]u32{ 0, 1, Q / 2, Q - 1, 12345, 8380416 };

    for (samples) |a| {
        // Only values in [0, Q) are valid inputs.
        if (a >= Q) continue;

        const split = power2Round(a);
        // a0 is returned offset by Q; remove the offset and reinterpret as signed.
        const a0: i32 = @bitCast(split.a0_plus_q -% Q);

        // Check reconstruction: a = a1*2^D + a0
        const rebuilt = @as(i32, @bitCast(split.a1 << D)) + a0;
        const a_signed: i32 = @bitCast(a);
        try testing.expectEqual(a_signed, rebuilt);

        // Check a0 bounds: -2^(D-1) < a0 <= 2^(D-1)
        try testing.expect(a0 > -bound and a0 <= bound);
    }
}
2277
test "decompose correctness for ML-DSA-44" {
    // Test decompose with gamma2 = 95232 = (Q-1)/88, the value used by ML-DSA-44.
    const gamma2 = 95232;
    const alpha = 2 * gamma2;

    const test_values = [_]u32{ 0, 1, Q / 2, Q - 1, 12345 };

    for (test_values) |a| {
        // Only values in [0, Q) are valid inputs.
        if (a >= Q) continue;

        const result = decompose(a, gamma2);
        // a0 is returned offset by Q; remove the offset and reinterpret as signed.
        const a0 = @as(i32, @bitCast(result.a0_plus_q -% Q));
        const a1 = result.a1;

        // Check reconstruction: a = a1*alpha + a0 (mod Q)
        var reconstructed: i64 = @as(i64, @intCast(a1)) * @as(i64, @intCast(alpha)) + @as(i64, a0);
        reconstructed = @mod(reconstructed, @as(i64, Q));
        try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);

        // Check a0 bounds (approximately): |a0| <= alpha / 2
        const bound: i32 = @intCast(alpha / 2);
        try testing.expect(@abs(a0) <= bound);
    }
}
2302
test "decompose correctness for ML-DSA-87" {
    // gamma2 = 261888 is shared by ML-DSA-65 and ML-DSA-87.
    const gamma2 = 261888;
    const alpha = 2 * gamma2;

    for ([_]u32{ 0, 1, Q / 2, Q - 1, 12345 }) |a| {
        // Only values in [0, Q) are valid inputs.
        if (a >= Q) continue;

        const parts = decompose(a, gamma2);
        // a0 is returned offset by Q; remove the offset and reinterpret as signed.
        const a0: i32 = @bitCast(parts.a0_plus_q -% Q);

        // Check reconstruction: a = a1*alpha + a0 (mod Q)
        const high_part = @as(i64, @intCast(parts.a1)) * @as(i64, @intCast(alpha));
        const reconstructed = @mod(high_part + @as(i64, a0), @as(i64, Q));
        try testing.expectEqual(@as(i64, @intCast(a)), reconstructed);

        // Check a0 bounds (approximately): |a0| <= alpha / 2
        const bound: i32 = @intCast(alpha / 2);
        try testing.expect(@abs(a0) <= bound);
    }
}
2327
test "polyDeriveUniform deterministic" {
    // The same (seed, nonce) pair must always expand to the same polynomial.
    const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31;
    const nonce: u16 = 0;

    const first = polyDeriveUniform(&seed, nonce);
    const second = polyDeriveUniform(&seed, nonce);

    // Both derivations agree coefficient by coefficient.
    for (first.cs, second.cs) |a, b| {
        try testing.expectEqual(a, b);
    }

    // All coefficients should be in [0, Q)
    for (first.cs) |coeff| {
        try testing.expect(coeff < Q);
    }
}
2346
test "polyDeriveUniform different nonces" {
    // Changing only the nonce must change the derived polynomial.
    const seed: [32]u8 = .{0x01} ++ .{0x00} ** 31;

    const with_nonce_0 = polyDeriveUniform(&seed, 0);
    const with_nonce_1 = polyDeriveUniform(&seed, 1);

    // A single differing coefficient is enough.
    const different = for (with_nonce_0.cs, with_nonce_1.cs) |a, b| {
        if (a != b) break true;
    } else false;

    try testing.expect(different);
}
2364
test "expandS with eta=2" {
    // Sample a polynomial with eta=2.
    const seed: [64]u8 = .{0x02} ++ .{0x00} ** 63;
    const nonce: u16 = 0;

    const sampled = expandS(2, &seed, nonce);

    // Coefficients are expected as Q + eta - t with t in [0, 2*eta],
    // so every coefficient must lie in [Q-2, Q+2].
    for (sampled.cs) |coeff| {
        try testing.expect(coeff >= Q - 2 and coeff <= Q + 2);
    }
}
2381
test "expandS with eta=4" {
    // Sample a polynomial with eta=4.
    const seed: [64]u8 = .{0x03} ++ .{0x00} ** 63;
    const nonce: u16 = 0;

    const sampled = expandS(4, &seed, nonce);

    // Every coefficient must lie within eta of Q, i.e. in [Q-4, Q+4].
    for (sampled.cs) |coeff| {
        const distance = if (coeff >= Q) coeff - Q else Q - coeff;
        try testing.expect(distance <= 4);
    }
}
2397
test "sampleInBall has correct weight" {
    // The sampled polynomial must have exactly tau non-zero coefficients.
    const tau = 39; // From ML-DSA-44
    const seed: [32]u8 = .{0x04} ++ .{0x00} ** 31;

    const ball = sampleInBall(tau, &seed);

    var weight: u32 = 0;
    for (ball.cs) |coeff| {
        if (coeff == 0) continue;
        weight += 1;
        // Non-zero coefficients should be 1 or Q-1
        try testing.expect(coeff == 1 or coeff == Q - 1);
    }

    try testing.expectEqual(tau, weight);
}
2417
test "sampleInBall deterministic" {
    // The same seed must always produce the same ball polynomial.
    const tau = 49; // From ML-DSA-65
    const seed: [32]u8 = .{0x05} ++ .{0x00} ** 31;

    const first = sampleInBall(tau, &seed);
    const second = sampleInBall(tau, &seed);

    // Both samples agree coefficient by coefficient.
    for (first.cs, second.cs) |a, b| {
        try testing.expectEqual(a, b);
    }
}
2431
test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=2" {
    // Pack and unpack a polynomial for eta=2 and expect it back unchanged.
    const eta = 2;

    // Coefficients cycle through Q + eta - {0, 1, 2, 3, 4}, i.e. the range [Q-eta, Q+eta].
    var original = Poly.zero;
    for (&original.cs, 0..) |*coeff, i| {
        coeff.* = Q + eta - @as(u32, @intCast(i % 5));
    }

    var packed_bytes: [96]u8 = undefined; // eta=2: 3 bits per coeff = 96 bytes
    polyPackLeqEta(original, eta, &packed_bytes);
    const unpacked = polyUnpackLeqEta(eta, &packed_bytes);

    // The round trip must be lossless.
    for (original.cs, unpacked.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2456
test "polyPackLeqEta / polyUnpackLeqEta roundtrip for eta=4" {
    // Pack and unpack a polynomial for eta=4 and expect it back unchanged.
    const eta = 4;

    // Coefficients cycle through Q + eta - {0, 1, ..., 8}, i.e. the range [Q-eta, Q+eta].
    var original = Poly.zero;
    for (&original.cs, 0..) |*coeff, i| {
        coeff.* = Q + eta - @as(u32, @intCast(i % 9));
    }

    var packed_bytes: [128]u8 = undefined; // eta=4: 4 bits per coeff = 128 bytes
    polyPackLeqEta(original, eta, &packed_bytes);
    const unpacked = polyUnpackLeqEta(eta, &packed_bytes);

    // The round trip must be lossless.
    for (original.cs, unpacked.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2481
test "polyPackT1 / polyUnpackT1 roundtrip" {
    // Every coefficient is below 1024, so it fits the 10-bit T1 encoding.
    var original = Poly.zero;
    for (&original.cs, 0..) |*coeff, i| {
        coeff.* = @intCast(i % 1024);
    }

    var packed_bytes: [320]u8 = undefined; // (256 * 10) / 8 = 320 bytes
    polyPackT1(original, &packed_bytes);
    const unpacked = polyUnpackT1(&packed_bytes);

    // The round trip must be lossless.
    for (original.cs, unpacked.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2501
test "polyPackT0 / polyUnpackT0 roundtrip" {
    // t0 coefficients lie in (Q-2^12, Q+2^12], i.e. the signed range
    // (-2^12, 2^12] represented as unsigned values around Q.
    const bound = 1 << 12; // 2^(D-1) where D=13
    const span = 2 * bound; // number of distinct offsets in (-bound, bound]

    // There are far fewer coefficients than offsets, so walking the range one
    // step at a time would only ever cover the most negative offsets. Use an even
    // stride instead, and pin the last coefficient to the upper bound so that
    // both extremes (-bound + 1 and +bound) are exercised.
    const stride = span / N;
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| {
        const cycle_val: i32 = @intCast(if (i == N - 1) span - 1 else i * stride);
        const offset = cycle_val - bound + 1; // (-bound+1) to bound
        c.* = @as(u32, @intCast(@as(i32, Q) + offset));
    }

    // Pack it
    var buf: [416]u8 = undefined; // (256 * 13) / 8 = 416 bytes
    polyPackT0(p, &buf);

    // Unpack it
    const p2 = polyUnpackT0(&buf);

    // Should be identical
    for (p.cs, p2.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2527
test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=17" {
    const gamma1_bits = 17;
    const gamma1: u32 = @as(u32, 1) << gamma1_bits;

    // Coefficients lie in (-gamma1, gamma1], stored normalized as
    // [0, gamma1] ∪ (Q-gamma1, Q).
    //
    // Even slots hold non-negative values and odd slots hold negative ones.
    // The magnitudes are spread evenly over the whole range (rather than just
    // counting up from zero) so the high bits of the encoding are exercised,
    // and the final pair is pinned to the two extremes: +gamma1 and -(gamma1 - 1).
    const pairs = N / 2;
    const step = gamma1 / pairs;
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| {
        const k = i / 2;
        const is_last = k == pairs - 1;
        if (i % 2 == 0) {
            // Positive values: [0, gamma1]
            c.* = if (is_last) gamma1 else @as(u32, @intCast(k * step));
        } else {
            // Negative values: (Q-gamma1, Q)
            const neg_val: u32 = if (is_last) gamma1 - 1 else @as(u32, @intCast(k * step + 1));
            c.* = Q - neg_val;
        }
    }

    // Pack it
    var buf: [576]u8 = undefined; // (256 * 18) / 8 = 576 bytes
    polyPackLeGamma1(p, gamma1_bits, &buf);

    // Unpack it
    const p2 = polyUnpackLeGamma1(gamma1_bits, &buf);

    // Should be identical
    for (p.cs, p2.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2558
test "polyPackLeGamma1 / polyUnpackLeGamma1 roundtrip gamma1_bits=19" {
    const gamma1_bits = 19;
    const gamma1: u32 = @as(u32, 1) << gamma1_bits;

    // Coefficients lie in (-gamma1, gamma1], stored normalized as
    // [0, gamma1] ∪ (Q-gamma1, Q).
    //
    // Even slots hold non-negative values and odd slots hold negative ones.
    // The magnitudes are spread evenly over the whole range (rather than just
    // counting up from zero) so the high bits of the encoding are exercised,
    // and the final pair is pinned to the two extremes: +gamma1 and -(gamma1 - 1).
    const pairs = N / 2;
    const step = gamma1 / pairs;
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| {
        const k = i / 2;
        const is_last = k == pairs - 1;
        if (i % 2 == 0) {
            // Positive values: [0, gamma1]
            c.* = if (is_last) gamma1 else @as(u32, @intCast(k * step));
        } else {
            // Negative values: (Q-gamma1, Q)
            const neg_val: u32 = if (is_last) gamma1 - 1 else @as(u32, @intCast(k * step + 1));
            c.* = Q - neg_val;
        }
    }

    // Pack it
    var buf: [640]u8 = undefined; // (256 * 20) / 8 = 640 bytes
    polyPackLeGamma1(p, gamma1_bits, &buf);

    // Unpack it
    const p2 = polyUnpackLeGamma1(gamma1_bits, &buf);

    // Should be identical
    for (p.cs, p2.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2588
test "polyPackW1 for gamma1_bits=17" {
    const gamma1_bits = 17;
    const coeff_bits = 6; // width of one packed w1 coefficient for this parameter set

    // Create a test polynomial with small coefficients (w1 values < 64)
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| {
        c.* = @intCast(i % 64); // 6-bit values
    }

    // Pack it
    var buf: [192]u8 = undefined; // (256 * 6) / 8 = 192 bytes
    polyPackW1(p, gamma1_bits, &buf);

    // Reference encoding: coefficients are concatenated as 6-bit little-endian
    // bit strings (FIPS 204 SimpleBitPack). Checking the exact bytes catches
    // encoding errors that a mere "output is not all zero" check would miss.
    var expected = [_]u8{0} ** 192;
    for (p.cs, 0..) |c, i| {
        const bit_pos = i * coeff_bits;
        const shifted = @as(u16, @intCast(c)) << @as(u4, @intCast(bit_pos % 8));
        expected[bit_pos / 8] |= @as(u8, @truncate(shifted));
        // A non-zero high byte means the coefficient straddles a byte boundary.
        if (shifted > 0xFF) {
            expected[bit_pos / 8 + 1] |= @as(u8, @truncate(shifted >> 8));
        }
    }
    try testing.expectEqualSlices(u8, &expected, &buf);
}
2613
test "polyPackW1 for gamma1_bits=19" {
    const gamma1_bits = 19;

    // Create a test polynomial with small coefficients (w1 values < 16)
    var p = Poly.zero;
    for (&p.cs, 0..) |*c, i| {
        c.* = @intCast(i % 16); // 4-bit values
    }

    // Pack it
    var buf: [128]u8 = undefined; // (256 * 4) / 8 = 128 bytes
    polyPackW1(p, gamma1_bits, &buf);

    // Reference encoding: two 4-bit coefficients per byte, the even-indexed one
    // in the low nibble (FIPS 204 SimpleBitPack). Checking the exact bytes catches
    // encoding errors that a mere "output is not all zero" check would miss.
    var expected: [128]u8 = undefined;
    for (&expected, 0..) |*byte, i| {
        byte.* = @intCast(p.cs[2 * i] | (p.cs[2 * i + 1] << 4));
    }
    try testing.expectEqualSlices(u8, &expected, &buf);
}
2637
test "makeHint and useHint correctness for gamma2=261888" {
    // gamma2 value used by ML-DSA-65 and ML-DSA-87.
    const gamma2: u32 = 261888;

    // Sample inputs, from zero up to the largest reduced value.
    const samples = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 };
    // Perturbation magnitudes in [0, gamma2].
    const deltas = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 };

    for (samples) |w| {
        // Split w into its high part and its low part (the latter offset by Q).
        const parts = decompose(w, gamma2);
        const high = parts.a1;
        const low_plus_q = parts.a0_plus_q;

        for (deltas) |f| {
            // Subtract f from both w and its low part; the hint computed from the
            // shifted low part must let useHint recover the original high part.
            const low_down = (low_plus_q +% Q -% f) % Q;
            const hint_down = makeHint(low_down, high, gamma2);
            const w_down = (w +% Q -% f) % Q;
            try testing.expectEqual(high, useHint(w_down, hint_down, gamma2));

            // The opposite direction is only distinct for a non-zero f.
            if (f == 0) continue;

            const low_up = (low_plus_q +% f) % Q;
            const hint_up = makeHint(low_up, high, gamma2);
            const w_up = (w +% f) % Q;
            try testing.expectEqual(high, useHint(w_up, hint_up, gamma2));
        }
    }
}
2673
test "makeHint and useHint correctness for gamma2=95232" {
    // gamma2 value used by ML-DSA-44.
    const gamma2: u32 = 95232;

    // Sample inputs, from zero up to the largest reduced value.
    const inputs = [_]u32{ 0, 100, 1000, 10000, 100000, 1000000, Q / 2, Q - 1 };
    // Perturbation magnitudes in [0, gamma2].
    const magnitudes = [_]u32{ 0, 1, 10, 100, 1000, gamma2 / 2, gamma2 };

    for (inputs) |w| {
        // Split w into its high part (w1) and its low part offset by Q.
        const split = decompose(w, gamma2);
        const w1 = split.a1;
        const w0_plus_q = split.a0_plus_q;

        for (magnitudes) |f| {
            // Case 1: f is subtracted from w and from its low part.
            {
                const shifted_low = (w0_plus_q +% Q -% f) % Q;
                const hint = makeHint(shifted_low, w1, gamma2);
                const shifted_w = (w +% Q -% f) % Q;
                const recovered = useHint(shifted_w, hint, gamma2);
                try testing.expectEqual(w1, recovered);
            }

            // Case 2: f is added instead; skipped for f == 0, where it would
            // repeat case 1.
            if (f > 0) {
                const shifted_low = (w0_plus_q +% f) % Q;
                const hint = makeHint(shifted_low, w1, gamma2);
                const shifted_w = (w +% f) % Q;
                const recovered = useHint(shifted_w, hint, gamma2);
                try testing.expectEqual(w1, recovered);
            }
        }
    }
}
2709
test "polyMakeHint basic functionality" {
    const gamma2: u32 = 261888;

    // Build two input polynomials: `low` takes multiples of 17 reduced mod Q,
    // `high` takes values below 16 (the original test notes that high bits
    // are at most 15 for this gamma2).
    var low = Poly.zero;
    var high = Poly.zero;
    for (&low.cs, &high.cs, 0..) |*lo, *hi, i| {
        lo.* = @intCast((i * 17) % Q);
        hi.* = @intCast((i * 3) % 16);
    }

    const made = polyMakeHint(low, high, gamma2);

    // Every hint coefficient must be a single bit, and the reported count must
    // equal the number of set bits.
    var ones: u32 = 0;
    for (made.hint.cs) |bit| {
        try testing.expect(bit == 0 or bit == 1);
        ones += bit;
    }
    try testing.expectEqual(made.count, ones);
}
2740
test "polyUseHint reconstruction" {
    const gamma2: u32 = 261888;

    // Input polynomial: multiples of 123 reduced mod Q.
    var q = Poly.zero;
    for (&q.cs, 0..) |*c, i| {
        c.* = @intCast((i * 123) % Q);
    }

    // Coefficient-wise decomposition of q into its low part (offset by Q)
    // and its high part.
    var low_plus_q = Poly.zero;
    var high = Poly.zero;
    for (q.cs, &low_plus_q.cs, &high.cs) |c, *lo, *hi| {
        const parts = decompose(c, gamma2);
        lo.* = parts.a0_plus_q;
        hi.* = parts.a1;
    }

    // Hints are derived from q's own, unperturbed decomposition.
    const hint = polyMakeHint(low_plus_q, high, gamma2).hint;

    // Applying those hints to q must give back q's high part.
    const recovered = polyUseHint(q, hint, gamma2);
    for (high.cs, recovered.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2774
test "hint roundtrip with perturbation" {
    const gamma2: u32 = 261888;

    var w1 = Poly.zero; // high part of w
    var w_prime = Poly.zero; // w - f (mod Q)
    var z0 = Poly.zero; // (low part of w, offset by Q) - f (mod Q)

    for (0..N) |i| {
        // Coefficient of the test polynomial w.
        const w_i: u32 = @intCast((i * 7919) % Q);
        const parts = decompose(w_i, gamma2);
        w1.cs[i] = parts.a1;

        // Small perturbation f: +magnitude on even indices, -magnitude
        // (stored as Q - magnitude) on odd ones.
        const magnitude: u32 = @intCast(i % 1000);
        const f_i: u32 = if (i % 2 == 0) magnitude else Q -% magnitude;

        w_prime.cs[i] = (w_i +% Q -% f_i) % Q;
        z0.cs[i] = (parts.a0_plus_q +% Q -% f_i) % Q;
    }

    // Hints computed from the perturbed low part must let polyUseHint recover
    // the original high part from the perturbed polynomial.
    const hint = polyMakeHint(z0, w1, gamma2).hint;
    const w1_recovered = polyUseHint(w_prime, hint, gamma2);

    for (w1.cs, w1_recovered.cs) |expected, actual| {
        try testing.expectEqual(expected, actual);
    }
}
2821
// Parameterized test helper for key generation.
//
// Derives a key pair from `seed` for the parameter set `MlDsa`, checks that the
// public and secret halves agree on `rho` and `tr`, and checks that both keys
// survive a toBytes/fromBytes round trip. Returns the first failed expectation,
// or the decoding error if a freshly serialized key is rejected by `fromBytes`.
fn testKeyGenerationBasic(comptime MlDsa: type, seed: [32]u8) !void {
    const result = MlDsa.newKeyFromSeed(&seed);
    const pk = result.pk;
    const sk = result.sk;

    // Basic sanity checks
    try testing.expect(pk.rho.len == 32);
    try testing.expect(sk.rho.len == 32);
    try testing.expectEqualSlices(u8, &pk.rho, &sk.rho);

    // Verify tr matches between pk and sk
    try testing.expectEqualSlices(u8, &pk.tr, &sk.tr);

    // Test toBytes/fromBytes round-trip for public key
    const pk_bytes = pk.toBytes();
    const pk2 = try MlDsa.PublicKey.fromBytes(pk_bytes);
    try testing.expectEqualSlices(u8, &pk.rho, &pk2.rho);
    try testing.expectEqualSlices(u8, &pk.tr, &pk2.tr);
    // Comparing the re-serialized key covers the whole encoding, not just the
    // fields checked above.
    try testing.expectEqualSlices(u8, &pk_bytes, &pk2.toBytes());

    // Test toBytes/fromBytes round-trip for secret key
    const sk_bytes = sk.toBytes();
    const sk2 = try MlDsa.SecretKey.fromBytes(sk_bytes);
    try testing.expectEqualSlices(u8, &sk.rho, &sk2.rho);
    try testing.expectEqualSlices(u8, &sk.key, &sk2.key);
    try testing.expectEqualSlices(u8, &sk.tr, &sk2.tr);
    // Same full-encoding check for the secret key.
    try testing.expectEqualSlices(u8, &sk_bytes, &sk2.toBytes());
}
2850
test "Key generation basic - all variants" {
    // One case per parameter set; each seed is a single byte repeated 32 times.
    const cases = .{
        .{ .scheme = MLDSA44, .fill = 0x44 },
        .{ .scheme = MLDSA65, .fill = 0x65 },
        .{ .scheme = MLDSA87, .fill = 0x87 },
    };
    inline for (cases) |case| {
        try testKeyGenerationBasic(case.scheme, [_]u8{case.fill} ** 32);
    }
}
2861
test "Key generation determinism" {
    const seed = [_]u8{ 0x12, 0x34, 0x56, 0x78 } ++ [_]u8{0xAB} ** 28;

    // Derive keys twice from one seed.
    const first = MLDSA44.newKeyFromSeed(&seed);
    const second = MLDSA44.newKeyFromSeed(&seed);

    // Both runs must serialize to the same public key ...
    try testing.expectEqualSlices(u8, &first.pk.toBytes(), &second.pk.toBytes());
    // ... and to the same secret key.
    try testing.expectEqualSlices(u8, &first.sk.toBytes(), &second.sk.toBytes());
}
2878
test "Private key can compute public key" {
    const seed = [_]u8{0xFF} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);

    // The public key recomputed from the secret key must serialize to the same
    // bytes as the one produced by key generation.
    const derived_pk = keys.sk.public();
    const generated = keys.pk.toBytes();
    const derived = derived_pk.toBytes();

    try testing.expectEqualSlices(u8, &generated, &derived);
}
2894
// Test helper shared by all parameter sets: derives a key pair from `seed`,
// signs `message` with no extra randomness (`null`), and checks that the
// signature verifies under the matching public key.
fn testSignAndVerify(comptime MlDsa: type, seed: [32]u8, message: []const u8) !void {
    const keys = MlDsa.newKeyFromSeed(&seed);
    const pair = try MlDsa.KeyPair.fromSecretKey(keys.sk);

    const signature = try pair.sign(message, null);
    try signature.verify(message, pair.public_key);
}
2906
test "Sign and verify - all variants" {
    // One case per parameter set: a repeated-byte seed and a short message.
    const cases = .{
        .{ .scheme = MLDSA44, .fill = 0x44, .text = "Hello, ML-DSA-44!" },
        .{ .scheme = MLDSA65, .fill = 0x65, .text = "Hello, ML-DSA-65!" },
        .{ .scheme = MLDSA87, .fill = 0x87, .text = "Hello, ML-DSA-87!" },
    };
    inline for (cases) |case| {
        try testSignAndVerify(case.scheme, [_]u8{case.fill} ** 32, case.text);
    }
}
2917
test "Invalid signature rejection" {
    const seed = [_]u8{0x99} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);

    const message = "Original message";
    const sig = try kp.sign(message, null);

    // A valid signature must not verify against a different message.
    const wrong_message = "Modified message";
    try testing.expectError(
        error.SignatureVerificationFailed,
        sig.verify(wrong_message, kp.public_key),
    );

    // Flipping every bit of the first signature byte must also be rejected,
    // even against the original message.
    var tampered_bytes = sig.toBytes();
    tampered_bytes[0] ^= 0xFF;
    const tampered = try MLDSA44.Signature.fromBytes(tampered_bytes);
    try testing.expectError(
        error.SignatureVerificationFailed,
        tampered.verify(message, kp.public_key),
    );
}
2938
test "Context string support" {
    const seed = [_]u8{0xAA} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);
    const pk = kp.public_key;

    const message = "Test message";
    const ctx_a = "context1";
    const ctx_b = "context2";

    // A signature made under ctx_a verifies under ctx_a only: not under a
    // different context, and not through the context-free `verify`.
    const sig_ctx = try kp.signWithContext(message, null, ctx_a);
    try sig_ctx.verifyWithContext(message, pk, ctx_a);
    try testing.expectError(error.SignatureVerificationFailed, sig_ctx.verifyWithContext(message, pk, ctx_b));
    try testing.expectError(error.SignatureVerificationFailed, sig_ctx.verify(message, pk));

    // Conversely, a signature made through the context-free `sign` verifies
    // through `verify` but not under a non-empty context.
    const sig_plain = try kp.sign(message, null);
    try sig_plain.verify(message, pk);
    try testing.expectError(error.SignatureVerificationFailed, sig_plain.verifyWithContext(message, pk, ctx_a));

    // A 255-byte context is the longest accepted one.
    const longest_ctx = [_]u8{0xBB} ** 255;
    const sig_longest = try kp.signWithContext(message, null, &longest_ctx);
    try sig_longest.verifyWithContext(message, pk, &longest_ctx);

    // One byte more and signing must be refused.
    const oversized_ctx = [_]u8{0xCC} ** 256;
    try testing.expectError(error.ContextTooLong, kp.signWithContext(message, null, &oversized_ctx));
}
2978
test "Context string with streaming API" {
    const seed = [_]u8{0xDD} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);

    const context = "streaming-context";
    // The message is fed to the signer and verifiers in two pieces.
    const parts = [_][]const u8{ "Hello, ", "World!" };

    // Incremental signing under `context`.
    var signer = try kp.signerWithContext(null, context);
    for (parts) |part| signer.update(part);
    const sig = signer.finalize();

    // Incremental verification under the same context succeeds.
    var good_verifier = try sig.verifierWithContext(kp.public_key, context);
    for (parts) |part| good_verifier.update(part);
    try good_verifier.verify();

    // The same message under a different context is rejected.
    var bad_verifier = try sig.verifierWithContext(kp.public_key, "wrong");
    for (parts) |part| bad_verifier.update(part);
    try testing.expectError(error.SignatureVerificationFailed, bad_verifier.verify());
}
3006
test "Signature determinism (same rnd)" {
    const seed = [_]u8{0x11} ** 32;
    const sk = MLDSA44.newKeyFromSeed(&seed).sk;

    const message = "Deterministic test";
    const rnd = [_]u8{0x22} ** 32;

    // Run the streaming signer twice with identical randomness and input.
    var first_signer = try sk.signer(rnd);
    first_signer.update(message);
    const first_sig = first_signer.finalize();

    var second_signer = try sk.signer(rnd);
    second_signer.update(message);
    const second_sig = second_signer.finalize();

    // Both runs must produce byte-identical signatures.
    try testing.expectEqualSlices(u8, &first_sig.toBytes(), &second_sig.toBytes());
}
3027
test "Signature toBytes/fromBytes roundtrip" {
    const seed = [_]u8{0x33} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);

    const message = "toBytes/fromBytes test";

    // Serialize a fresh signature ...
    const encoded = (try kp.sign(message, null)).toBytes();

    // ... then parse it back and serialize it again.
    const reparsed = try MLDSA44.Signature.fromBytes(encoded);
    const reencoded = reparsed.toBytes();

    // The second encoding must equal the first.
    try testing.expectEqualSlices(u8, &encoded, &reencoded);
}
3047
test "Empty message signing" {
    const seed = [_]u8{0x44} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);

    // A zero-length message must be signable and its signature must verify.
    const empty = "";
    const signature = try kp.sign(empty, null);
    try signature.verify(empty, kp.public_key);
}
3061
test "Long message signing" {
    const seed = [_]u8{0x55} ** 32;
    const keys = MLDSA44.newKeyFromSeed(&seed);
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);

    // 1 KiB message made of a single repeated byte.
    const payload = [_]u8{0xAB} ** 1024;

    // Signing and verifying must both succeed on the long input.
    const signature = try kp.sign(&payload, null);
    try signature.verify(&payload, kp.public_key);
}
3076
// Test helper: decodes the hex string `hex` into `out`, two digits per byte
// with the high nibble first. Upper- and lower-case digits are both accepted.
//
// Returns `error.InvalidLength` unless `hex` holds exactly two digits for each
// byte of `out`, and propagates the error from `std.fmt.charToDigit` when a
// character is not a hex digit.
fn hexToBytes(comptime hex: []const u8, out: []u8) !void {
    if (out.len * 2 != hex.len) return error.InvalidLength;

    for (out, 0..) |*byte, idx| {
        const high_nibble = try std.fmt.charToDigit(hex[2 * idx], 16);
        const low_nibble = try std.fmt.charToDigit(hex[2 * idx + 1], 16);
        byte.* = (high_nibble << 4) | low_nibble;
    }
}
3088
test "ML-DSA-44 KAT test vector 0" {
    // Entry count = 0 of the NIST ML-DSA KAT: the key-generation seed xi
    // (Algorithm 1, line 1), the first 32 bytes of the expected public key,
    // and the message.
    const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
    const pk_hex_start = "bd4e96f9a038ab5e36214fe69c0b1cb835ef9d7c8417e76aecd152f5cddebec8";
    const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";

    // Decode all three hex inputs up front.
    var xi: [32]u8 = undefined;
    var pk_prefix: [32]u8 = undefined;
    var msg: [16]u8 = undefined;
    try hexToBytes(xi_hex, &xi);
    try hexToBytes(pk_hex_start, &pk_prefix);
    try hexToBytes(msg_hex, &msg);

    // Key generation from xi must reproduce the expected public-key prefix.
    const keys = MLDSA44.newKeyFromSeed(&xi);
    const encoded_pk = keys.pk.toBytes();
    try testing.expectEqualSlices(u8, &pk_prefix, encoded_pk[0..32]);

    // Sign the KAT message without extra randomness and check that the
    // signature verifies.
    const kp = try MLDSA44.KeyPair.fromSecretKey(keys.sk);
    const sig = try kp.sign(&msg, null);
    try sig.verify(&msg, kp.public_key);
}
3125
test "ML-DSA-65 KAT test vector 0" {
    // Entry count = 0 of the NIST ML-DSA KAT: the key-generation seed xi
    // (Algorithm 1, line 1), the first 32 bytes of the expected public key,
    // and the message.
    const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
    const pk_hex_start = "e50d03fff3b3a70961abbb92a390008dec1283f603f50cdbaaa3d00bd659bc76";
    const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";

    // Decode all three hex inputs up front.
    var xi: [32]u8 = undefined;
    var pk_prefix: [32]u8 = undefined;
    var msg: [16]u8 = undefined;
    try hexToBytes(xi_hex, &xi);
    try hexToBytes(pk_hex_start, &pk_prefix);
    try hexToBytes(msg_hex, &msg);

    // Key generation from xi must reproduce the expected public-key prefix.
    const keys = MLDSA65.newKeyFromSeed(&xi);
    const encoded_pk = keys.pk.toBytes();
    try testing.expectEqualSlices(u8, &pk_prefix, encoded_pk[0..32]);

    // Sign the KAT message without extra randomness and check that the
    // signature verifies.
    const kp = try MLDSA65.KeyPair.fromSecretKey(keys.sk);
    const sig = try kp.sign(&msg, null);
    try sig.verify(&msg, kp.public_key);
}
3162
test "ML-DSA-87 KAT test vector 0" {
    // Entry count = 0 of the NIST ML-DSA KAT: the key-generation seed xi
    // (Algorithm 1, line 1), the first 32 bytes of the expected public key,
    // and the message.
    const xi_hex = "f696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68";
    const pk_hex_start = "bc89b367d4288f47c71a74679d0fcffbe041de41b5da2f5fc66d8e28c5899494";
    const msg_hex = "6dbbc4375136df3b07f7c70e639e223e";

    // Decode all three hex inputs up front.
    var xi: [32]u8 = undefined;
    var pk_prefix: [32]u8 = undefined;
    var msg: [16]u8 = undefined;
    try hexToBytes(xi_hex, &xi);
    try hexToBytes(pk_hex_start, &pk_prefix);
    try hexToBytes(msg_hex, &msg);

    // Key generation from xi must reproduce the expected public-key prefix.
    const keys = MLDSA87.newKeyFromSeed(&xi);
    const encoded_pk = keys.pk.toBytes();
    try testing.expectEqualSlices(u8, &pk_prefix, encoded_pk[0..32]);

    // Sign the KAT message without extra randomness and check that the
    // signature verifies.
    const kp = try MLDSA87.KeyPair.fromSecretKey(keys.sk);
    const sig = try kp.sign(&msg, null);
    try sig.verify(&msg, kp.public_key);
}
3199
test "KeyPair API - generate and sign" {
    // Freshly generated (random) key pair.
    const pair = MLDSA44.KeyPair.generate();
    const text = "Test message for KeyPair API";

    // Sign without noise, then check the result through Signature.verify.
    const signature = try pair.sign(text, null);
    try signature.verify(text, pair.public_key);
}
3211
test "KeyPair API - generateDeterministic" {
    const seed = [_]u8{42} ** 32;

    // Two key pairs derived from one seed ...
    const first = try MLDSA44.KeyPair.generateDeterministic(seed);
    const second = try MLDSA44.KeyPair.generateDeterministic(seed);

    // ... must expose the same serialized public key.
    const first_pk = first.public_key.toBytes();
    const second_pk = second.public_key.toBytes();
    try testing.expectEqualSlices(u8, &first_pk, &second_pk);
}
3223
test "KeyPair API - fromSecretKey" {
    // Start from a random key pair and rebuild a second pair from its secret
    // key alone.
    const generated = MLDSA44.KeyPair.generate();
    const rebuilt = try MLDSA44.KeyPair.fromSecretKey(generated.secret_key);

    // The rebuilt pair must carry the same public key.
    const generated_pk = generated.public_key.toBytes();
    const rebuilt_pk = rebuilt.public_key.toBytes();
    try testing.expectEqualSlices(u8, &generated_pk, &rebuilt_pk);
}
3236
test "Signature verification with noise" {
    // Hedged signing: the caller supplies 32 bytes of noise to the signer.
    const pair = MLDSA65.KeyPair.generate();
    const text = "Message to be signed with randomness";

    // 32 bytes of noise: five non-zero bytes followed by zero padding.
    const noise = [_]u8{ 1, 2, 3, 4, 5 } ++ [_]u8{0} ** 27;

    // A signature produced with noise must verify like any other.
    const signature = try pair.sign(text, noise);
    try signature.verify(text, pair.public_key);
}
3251
test "Signature verification failure" {
    const pair = MLDSA44.KeyPair.generate();
    const signed_text = "Original message";
    const other_text = "Different message";

    // A signature over one message must be rejected for any other message.
    const signature = try pair.sign(signed_text, null);
    try testing.expectError(
        error.SignatureVerificationFailed,
        signature.verify(other_text, pair.public_key),
    );
}
3262
test "Streaming API - sign and verify" {
    const seed = [_]u8{0x55} ** 32;
    const pair = try MLDSA44.KeyPair.generateDeterministic(seed);

    const text = "Test message for streaming API";

    // Feed the whole message to the incremental signer in a single update.
    var signing = try pair.signer(null);
    signing.update(text);
    const signature = signing.finalize();

    // The incremental verifier must accept the same input.
    var checking = try signature.verifier(pair.public_key);
    checking.update(text);
    try checking.verify();
}
3279
test "Streaming API - chunked message" {
    const seed = [_]u8{0x66} ** 32;
    const kp = try MLDSA44.KeyPair.generateDeterministic(seed);

    // The same message, once as three pieces and once as a single string.
    const chunk1 = "Hello, ";
    const chunk2 = "streaming ";
    const chunk3 = "world!";
    const pieces = [_][]const u8{ chunk1, chunk2, chunk3 };
    const whole = chunk1 ++ chunk2 ++ chunk3;

    // Sign piece by piece.
    var piecewise_signer = try kp.signer(null);
    for (pieces) |piece| piecewise_signer.update(piece);
    const sig_chunked = piecewise_signer.finalize();

    // Sign the whole message with a single update.
    var oneshot_signer = try kp.signer(null);
    oneshot_signer.update(whole);
    const sig_full = oneshot_signer.finalize();

    // How the input was split must not affect the signature.
    try testing.expectEqualSlices(u8, &sig_chunked.toBytes(), &sig_full.toBytes());

    // Piecewise verification must accept the piecewise signature.
    var verifier = try sig_chunked.verifier(kp.public_key);
    for (pieces) |piece| verifier.update(piece);
    try verifier.verify();
}
3313
test "Streaming API - large message" {
    const seed = [_]u8{0x77} ** 32;
    const kp = try MLDSA44.KeyPair.generateDeterministic(seed);

    // The message is one 4096-byte block repeated 256 times (1 MiB in total).
    // Block byte i holds the low 8 bits of i.
    const chunk_size = 4096;
    const num_chunks = 256;
    var block: [chunk_size]u8 = undefined;
    for (&block, 0..) |*byte, i| {
        byte.* = @truncate(i);
    }

    // Stream the message into the signer one block at a time.
    var signer = try kp.signer(null);
    var sent: usize = 0;
    while (sent < num_chunks) : (sent += 1) {
        signer.update(&block);
    }
    const sig = signer.finalize();

    // Stream the same blocks into the verifier.
    var verifier = try sig.verifier(kp.public_key);
    var checked: usize = 0;
    while (checked < num_chunks) : (checked += 1) {
        verifier.update(&block);
    }
    try verifier.verify();
}
3340
test "Streaming API - all parameter sets" {
    const test_msg = "Streaming test for all ML-DSA parameter sets";

    // Run the same streaming sign/verify round trip for each parameter set,
    // with a per-set seed made of one repeated byte.
    inline for (.{
        .{ .variant = MLDSA44, .seed_byte = 0x44 },
        .{ .variant = MLDSA65, .seed_byte = 0x65 },
        .{ .variant = MLDSA87, .seed_byte = 0x87 },
    }) |config| {
        const seed = [_]u8{config.seed_byte} ** 32;
        const kp = try config.variant.KeyPair.generateDeterministic(seed);

        var signer = try kp.signer(null);
        signer.update(test_msg);
        const sig = signer.finalize();

        var verifier = try sig.verifier(kp.public_key);
        verifier.update(test_msg);
        try verifier.verify();
    }
}
3380
/// Extended Euclidean algorithm.
/// Returns `gcd`, the last non-zero remainder of the Euclidean algorithm on
/// `a_` and `b_`, together with coefficients `x` and `y` satisfying
/// `a_ * x + b_ * y == gcd`.
/// Only meant to be used on comptime values; correctness matters, performance doesn't.
fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
    // Two consecutive rows (r, s, t) of the algorithm. Each row satisfies
    // r == a_ * s + b_ * t; the loop preserves this while shrinking r.
    var r_prev = a_;
    var r_curr = b_;
    var s_prev: T = 1;
    var s_curr: T = 0;
    var t_prev: T = 0;
    var t_curr: T = 1;

    while (r_curr != 0) {
        const quotient = @divTrunc(r_prev, r_curr);

        const r_next = r_prev - quotient * r_curr;
        r_prev = r_curr;
        r_curr = r_next;

        const s_next = s_prev - quotient * s_curr;
        s_prev = s_curr;
        s_curr = s_next;

        const t_next = t_prev - quotient * t_curr;
        t_prev = t_curr;
        t_curr = t_next;
    }

    return .{ .gcd = r_prev, .x = s_prev, .y = t_prev };
}
3408
/// Modular inversion: returns the multiplicative inverse of `a` modulo `p`.
/// The caller must guarantee gcd(a, p) = 1 (this is asserted).
/// The returned value is normalized to the range [0, p).
fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
    // Bézout coefficients can be negative, so an unsigned T is swapped for the
    // signed integer type of the same width while running the EEA.
    const Signed = switch (@typeInfo(T)) {
        .int => |info| if (info.signedness == .unsigned) std.meta.Int(.signed, info.bits) else T,
        else => T,
    };

    const modulus: Signed = @intCast(p);
    const bezout = extendedEuclidean(Signed, @as(Signed, @intCast(a)), modulus);
    assert(bezout.gcd == 1);

    // Shift a negative coefficient up by multiples of p until it is non-negative.
    var inverse = bezout.x;
    while (inverse < 0) inverse += modulus;

    return @intCast(inverse);
}
3433
/// Modular exponentiation: returns `a^s mod p`.
/// Uses right-to-left square-and-multiply; every product is formed in an
/// unsigned integer twice as wide as `T` before being reduced modulo `p`.
fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
    const Wide = std.meta.Int(.unsigned, 2 * @typeInfo(T).int.bits);

    var acc: T = 1;
    var square: T = a;
    var remaining = s;

    while (remaining > 0) : (remaining >>= 1) {
        // Fold the current power of `a` into the result when the low exponent bit is set.
        if (remaining & 1 == 1) {
            const product = @as(Wide, acc) * @as(Wide, square);
            acc = @intCast(product % p);
        }
        const squared = @as(Wide, square) * @as(Wide, square);
        square = @intCast(squared % p);
    }

    return acc;
}
3454
/// Expands a single bit into a full-width mask.
/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
/// `T` must be an unsigned integer type; anything else is a compile error.
fn bitMask(comptime T: type, bit: T) T {
    comptime {
        const is_unsigned_int = switch (@typeInfo(T)) {
            .int => |info| info.signedness == .unsigned,
            else => false,
        };
        if (!is_unsigned_int) {
            @compileError("bitMask requires an unsigned integer type");
        }
    }
    // Wrapping 0 - bit: 0 stays 0, 1 wraps around to the all-ones value.
    return @as(T, 0) -% bit;
}
3464
/// Spreads the most significant bit of `x` across an unsigned word of the same width.
/// Returns all 1s (0xFF...FF) if that bit is set (x < 0 for signed input), all 0s otherwise.
fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
    const info = @typeInfo(T);
    if (info != .int) {
        @compileError("signMask requires an integer type");
    }

    const width = info.int.bits;
    const Signed = std.meta.Int(.signed, width);
    const Unsigned = std.meta.Int(.unsigned, width);

    // Unsigned input is reinterpreted bit-for-bit so its top bit acts as a sign bit.
    const reinterpreted: Signed = switch (info.int.signedness) {
        .signed => x,
        .unsigned => @bitCast(x),
    };

    // An arithmetic right shift by width - 1 replicates the sign bit into every position.
    const smeared: Signed = reinterpreted >> (width - 1);
    return @as(Unsigned, @bitCast(smeared));
}
3481
/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q), with R = 2^r_bits.
/// This is a generic implementation parameterized by the modulus q, its inverse qInv,
/// and the Montgomery exponent r_bits.
///
/// Parameters:
/// - InT: Integer type of the input; must hold x and the product m*q without overflow
/// - OutT: Integer type of the result (and of the intermediate factor m)
/// - q: Modulus
/// - qInv: Inverse of q modulo R, i.e. q*qInv ≡ 1 (mod 2^r_bits); this is what makes
///   x - m*q a multiple of R below
/// - r_bits: log2 of the Montgomery constant R
/// - x: Value to reduce
///
/// The quotient (x - m*q) / R is extracted with a logical shift and then truncated
/// to the width of OutT, so a negative quotient comes back in two's-complement form.
///
/// Documented ranges (depend on the caller keeping x within bounds):
/// For ML-DSA: R = 2^32, returns y < 2q
/// For ML-KEM: R = 2^16, returns y in range (-q, q)
/// NOTE(review): with an unsigned InT the wrapping subtraction can go below zero;
/// confirm the ML-DSA bound against the actual caller before relying on it.
fn montgomeryReduce(
    comptime InT: type,
    comptime OutT: type,
    comptime q: comptime_int,
    comptime qInv: comptime_int,
    comptime r_bits: comptime_int,
    x: InT,
) OutT {
    // Unsigned views of both types, used for the logical shift and the final bit cast.
    // The lowercase `.int` field matches the other helpers in this file.
    const UnsignedIn = std.meta.Int(.unsigned, @typeInfo(InT).int.bits);
    const UnsignedOut = std.meta.Int(.unsigned, @typeInfo(OutT).int.bits);

    // m = x * qInv mod R, then reinterpreted in OutT.
    const mask = (@as(InT, 1) << r_bits) - 1;
    const m_full = (x *% qInv) & mask;
    const m: OutT = @truncate(m_full);

    // x - m*q ≡ 0 (mod R), so dropping the low r_bits bits is an exact division by R.
    const yR = x -% @as(InT, m) * @as(InT, q);
    const y_shifted = @as(UnsignedIn, @bitCast(yR)) >> r_bits;
    return @bitCast(@as(UnsignedOut, @truncate(y_shifted)));
}
3504
/// Rejection-samples `n` polynomial coefficients in [0, q) from a SHAKE-128 stream.
///
/// The XOF absorbs `seed` followed by `domain_sep`; its output is consumed one
/// SHAKE-128 block at a time, and candidates that are not below `q` are discarded.
///
/// Parameters:
/// - PolyType: The polynomial type to return; its `cs` field receives the coefficients
/// - q: Modulus (exclusive upper bound for accepted candidates)
/// - bits_per_coef: Candidate width in bits; must be 12 or 23 (compile error otherwise)
/// - n: Number of coefficients to produce
/// - seed: Random seed
/// - domain_sep: Domain separation bytes (absorbed right after the seed)
fn sampleUniformRejection(
    comptime PolyType: type,
    comptime q: comptime_int,
    comptime bits_per_coef: comptime_int,
    comptime n: comptime_int,
    seed: []const u8,
    domain_sep: []const u8,
) PolyType {
    var xof = sha3.Shake128.init(.{});
    xof.update(seed);
    xof.update(domain_sep);

    const block_len = sha3.Shake128.block_length; // 168 bytes
    var block: [block_len]u8 = undefined;

    var poly: PolyType = undefined;
    var filled: usize = 0;

    switch (bits_per_coef) {
        12 => {
            // ML-KEM path: every 3 bytes carry two 12-bit candidates.
            fill: while (true) {
                xof.squeeze(&block);

                var off: usize = 0;
                while (off < block_len) : (off += 3) {
                    const lo: u16 = block[off];
                    const mid: u16 = block[off + 1];
                    const hi: u16 = block[off + 2];

                    // The first candidate takes the low nibble of the middle byte,
                    // the second one takes its high nibble.
                    const candidates = [2]u16{
                        lo | ((mid & 0xf) << 8),
                        (mid >> 4) | (hi << 4),
                    };

                    inline for (candidates) |candidate| {
                        if (candidate < q) {
                            poly.cs[filled] = @intCast(candidate);
                            filled += 1;
                            if (filled == n) break :fill;
                        }
                    }
                }
            }
        },
        23 => {
            // ML-DSA path: every 3 bytes carry one 23-bit candidate (top bit cleared).
            while (filled < n) {
                xof.squeeze(&block);

                var off: usize = 0;
                while (off < block_len and filled < n) : (off += 3) {
                    const raw = @as(u32, block[off]) |
                        (@as(u32, block[off + 1]) << 8) |
                        (@as(u32, block[off + 2]) << 16);
                    const candidate = raw & 0x7fffff;

                    if (candidate >= q) continue;
                    poly.cs[filled] = @intCast(candidate);
                    filled += 1;
                }
            }
        },
        else => @compileError("bits_per_coef must be 12 or 23"),
    }

    return poly;
}
3581
// Checks both mask helpers against their documented all-ones / all-zeros outputs.
test "bitMask and signMask helpers" {
    // bitMask: 0 -> all zeros, 1 -> all ones, across several unsigned widths.
    inline for (.{ u32, u8, u64 }) |U| {
        try testing.expectEqual(@as(U, 0), bitMask(U, 0));
        try testing.expectEqual(~@as(U, 0), bitMask(U, 1));
    }

    const all_ones: u32 = 0xFFFFFFFF;
    const no_bits: u32 = 0x00000000;

    // signMask on signed input: all ones exactly when the value is negative.
    for ([_]i32{ -1, -100 }) |negative| {
        try testing.expectEqual(all_ones, signMask(i32, negative));
    }
    for ([_]i32{ 0, 1, 100 }) |non_negative| {
        try testing.expectEqual(no_bits, signMask(i32, non_negative));
    }

    // signMask on unsigned input keys off the most significant bit.
    try testing.expectEqual(all_ones, signMask(u32, 0x80000000)); // MSB set
    try testing.expectEqual(no_bits, signMask(u32, 0x7FFFFFFF)); // MSB clear
}