Commit 4ea4728084

Frank Denis <124872+jedisct1@users.noreply.github.com>
2025-11-18 16:39:58
Align ML-KEM code with ML-DSA (#25964)
This will facilitate maintenance and code sharing between primitives.
1 parent 73f863a
Changed files (1)
lib/std/crypto/ml_kem.zig
@@ -105,19 +105,20 @@ const crypto = std.crypto;
 const errors = std.crypto.errors;
 const math = std.math;
 const mem = std.mem;
-const RndGen = std.Random.DefaultPrng;
 const sha3 = crypto.hash.sha3;
 
-// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1.
+const RndGen = std.Random.DefaultPrng;
+
+// Q is the modulus q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1
 const Q: i16 = 3329;
 
-// Montgomery R
+// Montgomery constant R = 2¹⁶, used for Montgomery multiplication
 const R: i32 = 1 << 16;
 
-// Parameter n, degree of polynomials.
+// N is the degree of polynomials (polynomial ring dimension)
 const N: usize = 256;
 
-// Size of "small" vectors used in encryption blinds.
+// eta2 is the size of "small" vectors used in encryption blinds
 const eta2: u8 = 2;
 
 const Params = struct {
@@ -215,7 +216,7 @@ fn Kyber(comptime p: Params) type {
         pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
 
         const Self = @This();
-        const V = Vec(p.k);
+        const V = PolyVec(p.k);
         const M = Mat(p.k);
 
         /// Length (in bytes) of a shared secret.
@@ -241,7 +242,7 @@ fn Kyber(comptime p: Params) type {
             hpk: [h_length]u8, // H(pk)
 
             /// Size of a serialized representation of the key, in bytes.
-            pub const bytes_length = InnerPk.bytes_length;
+            pub const encoded_length = InnerPk.encoded_length;
 
             /// Generates a shared secret, and encapsulates it for the public key.
             /// If `seed` is `null`, a random seed is used. This is recommended.
@@ -289,14 +290,14 @@ fn Kyber(comptime p: Params) type {
             }
 
             /// Serializes the key into a byte array.
-            pub fn toBytes(pk: PublicKey) [bytes_length]u8 {
+            pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
                 return pk.pk.toBytes();
             }
 
             /// Deserializes the key from a byte array.
-            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
+            pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
                 var ret: PublicKey = undefined;
-                ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
+                ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
                 sha3.Sha3_256.hash(buf, &ret.hpk, .{});
                 return ret;
             }
@@ -310,8 +311,8 @@ fn Kyber(comptime p: Params) type {
             z: [shared_length]u8,
 
             /// Size of a serialized representation of the key, in bytes.
-            pub const bytes_length: usize =
-                InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length;
+            pub const encoded_length: usize =
+                InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
 
             /// Decapsulates the shared secret within ct using the private key.
             pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
@@ -346,18 +347,18 @@ fn Kyber(comptime p: Params) type {
             }
 
             /// Serializes the key into a byte array.
-            pub fn toBytes(sk: SecretKey) [bytes_length]u8 {
+            pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
                 return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
             }
 
             /// Deserializes the key from a byte array.
-            pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
+            pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
                 var ret: SecretKey = undefined;
                 comptime var s: usize = 0;
-                ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
-                s += InnerSk.bytes_length;
-                ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
-                s += InnerPk.bytes_length;
+                ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
+                s += InnerSk.encoded_length;
+                ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
+                s += InnerPk.encoded_length;
                 ret.hpk = buf[s..][0..h_length].*;
                 s += h_length;
                 ret.z = buf[s..][0..shared_length].*;
@@ -418,7 +419,7 @@ fn Kyber(comptime p: Params) type {
             // Cached values
             aT: M,
 
-            const bytes_length = V.bytes_length + 32;
+            const encoded_length = V.encoded_length + 32;
 
             fn encrypt(
                 pk: InnerPk,
@@ -436,7 +437,7 @@ fn Kyber(comptime p: Params) type {
                     // Note that coefficients of r are bounded by q and those of Aᵀ
                     // are bounded by 4.5q and so their product is bounded by 2¹⁵q
                     // as required for multiplication.
-                    u.ps[i] = pk.aT.vs[i].dotHat(rh);
+                    u.ps[i] = pk.aT.rows[i].dotHat(rh);
                 }
 
                 // Aᵀ and r were not in Montgomery form, so the Montgomery
@@ -451,14 +452,14 @@ fn Kyber(comptime p: Params) type {
                 return u.compress(p.du) ++ v.compress(p.dv);
             }
 
-            fn toBytes(pk: InnerPk) [bytes_length]u8 {
+            fn toBytes(pk: InnerPk) [encoded_length]u8 {
                 return pk.th.toBytes() ++ pk.rho;
             }
 
-            fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
+            fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
                 var ret: InnerPk = undefined;
 
-                const th_bytes = buf[0..V.bytes_length];
+                const th_bytes = buf[0..V.encoded_length];
                 ret.th = V.fromBytes(th_bytes).normalize();
 
                 if (p.ml_kem) {
@@ -468,7 +469,7 @@ fn Kyber(comptime p: Params) type {
                     }
                 }
 
-                ret.rho = buf[V.bytes_length..bytes_length].*;
+                ret.rho = buf[V.encoded_length..encoded_length].*;
                 ret.aT = M.uniform(ret.rho, true);
                 return ret;
             }
@@ -477,7 +478,7 @@ fn Kyber(comptime p: Params) type {
         // Private key of the inner PKE
         const InnerSk = struct {
             sh: V, // NTT(s), normalized
-            const bytes_length = V.bytes_length;
+            const encoded_length = V.encoded_length;
 
             fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
                 const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
@@ -491,11 +492,11 @@ fn Kyber(comptime p: Params) type {
                     .normalize().compress(1);
             }
 
-            fn toBytes(sk: InnerSk) [bytes_length]u8 {
+            fn toBytes(sk: InnerSk) [encoded_length]u8 {
                 return sk.sh.toBytes();
             }
 
-            fn fromBytes(buf: *const [bytes_length]u8) InnerSk {
+            fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
                 var ret: InnerSk = undefined;
                 ret.sh = V.fromBytes(buf).normalize();
                 return ret;
@@ -516,7 +517,7 @@ fn Kyber(comptime p: Params) type {
             // Sample secret vector s.
             sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
 
-            const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
+            const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
             var th: V = undefined;
 
             // Next, we compute t = A s + e.
@@ -528,7 +529,7 @@ fn Kyber(comptime p: Params) type {
                 // multiplications in the inner product added a factor R⁻¹ which
                 // we'll cancel out with toMont().  This will also ensure the
                 // coefficients of th are bounded in absolute value by q.
-                th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont();
+                th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
             }
 
             pk.th = th.add(eh).normalize(); // bounded by 8q
@@ -565,7 +566,6 @@ const zetas = computeZetas();
 // not enough, the other coefficient is reduced as well.
 //
 // This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
-// TODO generate comptime?
 const inv_ntt_reductions = [_]i16{
     -1, // after layer 1
     -1, // after layer 2
@@ -634,31 +634,8 @@ test "invNTTReductions bounds" {
     }
 }
 
-// Extended euclidean algorithm.
-//
-// For a, b finds x, y such that  x a + y b = gcd(a, b). Used to compute
-// modular inverse.
-fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
-    if (a == 0) {
-        return .{ .gcd = b, .x = 0, .y = 1 };
-    }
-    const r = eea(@rem(b, a), a);
-    return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
-}
-
-fn EeaResult(comptime T: type) type {
-    return struct { gcd: T, x: T, y: T };
-}
-
-// Returns least common multiple of a and b.
-fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
-    const r = eea(a, b);
-    return a * b / r.gcd;
-}
-
-// Invert modulo p.
 fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
-    const r = eea(a, p);
+    const r = extendedEuclidean(@TypeOf(a), a, p);
     assert(r.gcd == 1);
     return r.x;
 }
@@ -788,31 +765,12 @@ test "Test csubq" {
     }
 }
 
-// Compute a^s mod p.
-fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) {
-    var ret: @TypeOf(a) = 1;
-    var s2 = s;
-    var a2 = a;
-
-    while (true) {
-        if (s2 & 1 == 1) {
-            ret = @mod(ret * a2, p);
-        }
-        s2 >>= 1;
-        if (s2 == 0) {
-            break;
-        }
-        a2 = @mod(a2 * a2, p);
-    }
-    return ret;
-}
-
 // Computes zetas table used by ntt and invNTT.
 fn computeZetas() [128]i16 {
     @setEvalBranchQuota(10000);
     var ret: [128]i16 = undefined;
     for (&ret, 0..) |*r, i| {
-        const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q)));
+        const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
         r.* = csubq(feBarrettReduce(feToMont(t)));
     }
     return ret;
@@ -828,9 +786,10 @@ fn computeZetas() [128]i16 {
 const Poly = struct {
     cs: [N]i16,
 
-    const bytes_length = N / 2 * 3;
+    const encoded_length = N / 2 * 3;
     const zero: Poly = .{ .cs = .{0} ** N };
 
+    // Add two polynomials (coefficients not normalized)
     fn add(a: Poly, b: Poly) Poly {
         var ret: Poly = undefined;
         for (0..N) |i| {
@@ -839,6 +798,7 @@ const Poly = struct {
         return ret;
     }
 
+    // Subtract two polynomials (coefficients not normalized)
     fn sub(a: Poly, b: Poly) Poly {
         var ret: Poly = undefined;
         for (0..N) |i| {
@@ -847,25 +807,6 @@ const Poly = struct {
         return ret;
     }
 
-    // For testing, generates a random polynomial with for each
-    // coefficient |x| ≤ q.
-    fn randAbsLeqQ(rnd: anytype) Poly {
-        var ret: Poly = undefined;
-        for (0..N) |i| {
-            ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
-        }
-        return ret;
-    }
-
-    // For testing, generates a random normalized polynomial.
-    fn randNormalized(rnd: anytype) Poly {
-        var ret: Poly = undefined;
-        for (0..N) |i| {
-            ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
-        }
-        return ret;
-    }
-
     // Executes a forward "NTT" on p.
     //
     // Assumes the coefficients are in absolute value ≤q.  The resulting
@@ -1054,7 +995,7 @@ const Poly = struct {
         var in_off: usize = 0;
         var out_off: usize = 0;
 
-        const batch_size: usize = comptime lcm(@as(i16, d), 8);
+        const batch_size: usize = comptime math.lcm(d, 8);
         const in_batch_size: usize = comptime batch_size / d;
         const out_batch_size: usize = comptime batch_size / 8;
 
@@ -1118,7 +1059,7 @@ const Poly = struct {
         var in_off: usize = 0;
         var out_off: usize = 0;
 
-        const batch_size: usize = comptime lcm(@as(i16, d), 8);
+        const batch_size: usize = comptime math.lcm(d, 8);
         const in_batch_size: usize = comptime batch_size / 8;
         const out_batch_size: usize = comptime batch_size / d;
 
@@ -1275,53 +1216,23 @@ const Poly = struct {
         return ret;
     }
 
-    // Sample p uniformly from the given seed and x and y coordinates.
     fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
-        var h = sha3.Shake128.init(.{});
-        const suffix: [2]u8 = .{ x, y };
-        h.update(&seed);
-        h.update(&suffix);
-
-        const buf_len = sha3.Shake128.block_length; // rate SHAKE-128
-        var buf: [buf_len]u8 = undefined;
-
-        var ret: Poly = undefined;
-        var i: usize = 0; // index into ret.cs
-        outer: while (true) {
-            h.squeeze(&buf);
-
-            var j: usize = 0; // index into buf
-            while (j < buf_len) : (j += 3) {
-                const b0 = @as(u16, buf[j]);
-                const b1 = @as(u16, buf[j + 1]);
-                const b2 = @as(u16, buf[j + 2]);
-
-                const ts: [2]u16 = .{
-                    b0 | ((b1 & 0xf) << 8),
-                    (b1 >> 4) | (b2 << 4),
-                };
-
-                inline for (ts) |t| {
-                    if (t < Q) {
-                        ret.cs[i] = @as(i16, @intCast(t));
-                        i += 1;
-
-                        if (i == N) {
-                            break :outer;
-                        }
-                    }
-                }
-            }
-        }
-
-        return ret;
+        const domain_sep: [2]u8 = .{ x, y };
+        return sampleUniformRejection(
+            Poly,
+            Q,
+            12,
+            N,
+            &seed,
+            &domain_sep,
+        );
     }
 
     // Packs p.
     //
     // Assumes p is normalized (and not just Barrett reduced).
-    fn toBytes(p: Poly) [bytes_length]u8 {
-        var ret: [bytes_length]u8 = undefined;
+    fn toBytes(p: Poly) [encoded_length]u8 {
+        var ret: [encoded_length]u8 = undefined;
         for (0..comptime N / 2) |i| {
             const t0 = @as(u16, @intCast(p.cs[2 * i]));
             const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
@@ -1335,7 +1246,7 @@ const Poly = struct {
     // Unpacks a Poly from buf.
     //
     // p will not be normalized; instead 0 ≤ p[i] < 4096.
-    fn fromBytes(buf: *const [bytes_length]u8) Poly {
+    fn fromBytes(buf: *const [encoded_length]u8) Poly {
         var ret: Poly = undefined;
         for (0..comptime N / 2) |i| {
             const b0 = @as(i16, buf[3 * i]);
@@ -1348,71 +1259,65 @@ const Poly = struct {
     }
 };
 
-// A vector of K polynomials.
-fn Vec(comptime K: u8) type {
+// A vector of k polynomials.
+fn PolyVec(comptime k: u8) type {
     return struct {
-        ps: [K]Poly,
+        ps: [k]Poly,
 
         const Self = @This();
-        const bytes_length = K * Poly.bytes_length;
+        const encoded_length = k * Poly.encoded_length;
 
         fn compressedSize(comptime d: u8) usize {
-            return Poly.compressedSize(d) * K;
+            return Poly.compressedSize(d) * k;
         }
 
-        fn ntt(a: Self) Self {
+        /// Applies a unary operation to each polynomial.
+        fn map(v: Self, comptime op: fn (Poly) Poly) Self {
             var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].ntt();
+            inline for (0..k) |i| {
+                ret.ps[i] = op(v.ps[i]);
             }
             return ret;
         }
 
-        fn invNTT(a: Self) Self {
+        /// Applies a binary operation pairwise.
+        fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
             var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].invNTT();
+            inline for (0..k) |i| {
+                ret.ps[i] = op(a.ps[i], b.ps[i]);
             }
             return ret;
         }
 
-        fn normalize(a: Self) Self {
-            var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].normalize();
-            }
-            return ret;
+        fn ntt(v: Self) Self {
+            return map(v, Poly.ntt);
         }
 
-        fn barrettReduce(a: Self) Self {
-            var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].barrettReduce();
-            }
-            return ret;
+        fn invNTT(v: Self) Self {
+            return map(v, Poly.invNTT);
+        }
+
+        fn normalize(v: Self) Self {
+            return map(v, Poly.normalize);
+        }
+
+        fn barrettReduce(v: Self) Self {
+            return map(v, Poly.barrettReduce);
         }
 
         fn add(a: Self, b: Self) Self {
-            var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].add(b.ps[i]);
-            }
-            return ret;
+            return mapBinary(a, b, Poly.add);
         }
 
         fn sub(a: Self, b: Self) Self {
-            var ret: Self = undefined;
-            for (0..K) |i| {
-                ret.ps[i] = a.ps[i].sub(b.ps[i]);
-            }
-            return ret;
+            return mapBinary(a, b, Poly.sub);
         }
 
         // Samples v[i] from centered binomial distribution with the given η,
         // seed and nonce+i.
         fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
             var ret: Self = undefined;
-            for (0..K) |i| {
+            for (0..k) |i| {
                 ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
             }
             return ret;
@@ -1428,7 +1333,7 @@ fn Vec(comptime K: u8) type {
         // of the Montgomery factor.
         fn dotHat(a: Self, b: Self) Poly {
             var ret: Poly = Poly.zero;
-            for (0..K) |i| {
+            for (0..k) |i| {
                 ret = ret.add(a.ps[i].mulHat(b.ps[i]));
             }
             return ret;
@@ -1437,7 +1342,7 @@ fn Vec(comptime K: u8) type {
         fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
             const cs = comptime Poly.compressedSize(d);
             var ret: [compressedSize(d)]u8 = undefined;
-            inline for (0..K) |i| {
+            inline for (0..k) |i| {
                 ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
             }
             return ret;
@@ -1446,27 +1351,27 @@ fn Vec(comptime K: u8) type {
         fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
             const cs = comptime Poly.compressedSize(d);
             var ret: Self = undefined;
-            inline for (0..K) |i| {
+            inline for (0..k) |i| {
                 ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
             }
             return ret;
         }
 
         /// Serializes the key into a byte array.
-        fn toBytes(v: Self) [bytes_length]u8 {
-            var ret: [bytes_length]u8 = undefined;
-            inline for (0..K) |i| {
-                ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes();
+        fn toBytes(v: Self) [encoded_length]u8 {
+            var ret: [encoded_length]u8 = undefined;
+            inline for (0..k) |i| {
+                ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
             }
             return ret;
         }
 
         /// Deserializes the key from a byte array.
-        fn fromBytes(buf: *const [bytes_length]u8) Self {
+        fn fromBytes(buf: *const [encoded_length]u8) Self {
             var ret: Self = undefined;
-            inline for (0..K) |i| {
+            inline for (0..k) |i| {
                 ret.ps[i] = Poly.fromBytes(
-                    buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length],
+                    buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
                 );
             }
             return ret;
@@ -1474,19 +1379,19 @@ fn Vec(comptime K: u8) type {
     };
 }
 
-// A matrix of K vectors
-fn Mat(comptime K: u8) type {
+// A matrix of k vectors
+fn Mat(comptime k: u8) type {
     return struct {
         const Self = @This();
-        vs: [K]Vec(K),
+        rows: [k]PolyVec(k),
 
         fn uniform(seed: [32]u8, comptime transposed: bool) Self {
             var ret: Self = undefined;
             var i: u8 = 0;
-            while (i < K) : (i += 1) {
+            while (i < k) : (i += 1) {
                 var j: u8 = 0;
-                while (j < K) : (j += 1) {
-                    ret.vs[i].ps[j] = Poly.uniform(
+                while (j < k) : (j += 1) {
+                    ret.rows[i].ps[j] = Poly.uniform(
                         seed,
                         if (transposed) i else j,
                         if (transposed) j else i,
@@ -1499,9 +1404,9 @@ fn Mat(comptime K: u8) type {
         // Returns transpose of A
         fn transpose(m: Self) Self {
             var ret: Self = undefined;
-            for (0..K) |i| {
-                for (0..K) |j| {
-                    ret.vs[i].ps[j] = m.vs[j].ps[i];
+            for (0..k) |i| {
+                for (0..k) |j| {
+                    ret.rows[i].ps[j] = m.rows[j].ps[i];
                 }
             }
             return ret;
@@ -1522,12 +1427,30 @@ fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
     }
 }
 
+// Test helper: generates a random polynomial with each coefficient |x| ≤ q
+fn randPolyAbsLeqQ(rnd: anytype) Poly {
+    var ret: Poly = undefined;
+    for (0..N) |i| {
+        ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
+    }
+    return ret;
+}
+
+// Test helper: generates a random normalized polynomial
+fn randPolyNormalized(rnd: anytype) Poly {
+    var ret: Poly = undefined;
+    for (0..N) |i| {
+        ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
+    }
+    return ret;
+}
+
 test "MulHat" {
     var rnd = RndGen.init(0);
 
     for (0..100) |_| {
-        const a = Poly.randAbsLeqQ(&rnd);
-        const b = Poly.randAbsLeqQ(&rnd);
+        const a = randPolyAbsLeqQ(&rnd);
+        const b = randPolyAbsLeqQ(&rnd);
 
         const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
         var p: Poly = undefined;
@@ -1557,7 +1480,7 @@ test "NTT" {
     var rnd = RndGen.init(0);
 
     for (0..1000) |_| {
-        var p = Poly.randAbsLeqQ(&rnd);
+        var p = randPolyAbsLeqQ(&rnd);
         const q = p.toMont().normalize();
         p = p.ntt();
 
@@ -1580,7 +1503,7 @@ test "Compression" {
     var rnd = RndGen.init(0);
     inline for (.{ 1, 4, 5, 10, 11 }) |d| {
         for (0..1000) |_| {
-            const p = Poly.randNormalized(&rnd);
+            const p = randPolyNormalized(&rnd);
             const pp = p.compress(d);
             const pq = Poly.decompress(d, &pp).compress(d);
             try testing.expectEqual(pp, pq);
@@ -1671,7 +1594,7 @@ test "Polynomial packing" {
     var rnd = RndGen.init(0);
 
     for (0..1000) |_| {
-        const p = Poly.randNormalized(&rnd);
+        const p = randPolyNormalized(&rnd);
         try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
     }
 }
@@ -1839,3 +1762,222 @@ const NistDRBG = struct {
         return ret;
     }
 };
+
+/// Extended Euclidean algorithm.
+/// Only meant to be used on comptime values; correctness matters, performance doesn't.
+fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
+    var a = a_;
+    var b = b_;
+    var x0: T = 1;
+    var x1: T = 0;
+    var y0: T = 0;
+    var y1: T = 1;
+
+    while (b != 0) {
+        const q = @divTrunc(a, b);
+        const temp_a = a;
+        a = b;
+        b = temp_a - q * b;
+
+        const temp_x = x0;
+        x0 = x1;
+        x1 = temp_x - q * x1;
+
+        const temp_y = y0;
+        y0 = y1;
+        y1 = temp_y - q * y1;
+    }
+
+    return .{ .gcd = a, .x = x0, .y = y0 };
+}
+
+/// Modular inversion: computes a^(-1) mod p
+/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p).
+fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
+    // Use a signed type for EEA computation
+    const type_info = @typeInfo(T);
+    const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
+        std.meta.Int(.signed, type_info.int.bits)
+    else
+        T;
+
+    const a_signed = @as(SignedT, @intCast(a));
+    const p_signed = @as(SignedT, @intCast(p));
+
+    const r = extendedEuclidean(SignedT, a_signed, p_signed);
+    assert(r.gcd == 1);
+
+    // Normalize result to [0, p)
+    var result = r.x;
+    while (result < 0) {
+        result += p_signed;
+    }
+
+    return @intCast(result);
+}
+
+/// Modular exponentiation: computes a^s mod p using the square-and-multiply algorithm.
+fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
+    const type_info = @typeInfo(T);
+    const bits = type_info.int.bits;
+    const WideT = std.meta.Int(.unsigned, bits * 2);
+
+    var ret: T = 1;
+    var base: T = a;
+    var exp = s;
+
+    while (exp > 0) {
+        if (exp & 1 == 1) {
+            ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
+        }
+        base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
+        exp >>= 1;
+    }
+
+    return ret;
+}
+
+/// Creates an all-ones or all-zeros mask from a single bit value.
+/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
+fn bitMask(comptime T: type, bit: T) T {
+    const type_info = @typeInfo(T);
+    if (type_info != .int or type_info.int.signedness != .unsigned) {
+        @compileError("bitMask requires an unsigned integer type");
+    }
+    return -%bit;
+}
+
+/// Creates a mask from the sign bit of a signed integer.
+/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0.
+fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
+    const type_info = @typeInfo(T);
+    if (type_info != .int) {
+        @compileError("signMask requires an integer type");
+    }
+
+    const bits = type_info.int.bits;
+    const SignedT = std.meta.Int(.signed, bits);
+
+    // Convert to signed if needed, arithmetic right shift to propagate sign bit
+    const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
+    const shifted = x_signed >> (bits - 1);
+    return @bitCast(shifted);
+}
+
+test "bitMask and signMask helpers" {
+    try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0));
+    try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1));
+    try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0));
+    try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1));
+    try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0));
+    try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1));
+
+    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1));
+    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100));
+    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0));
+    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1));
+    try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100));
+
+    try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
+    try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
+}
+
+/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
+/// This is a generic implementation parameterized by the modulus q, its inverse qInv,
+/// the Montgomery constant R, and the result bound.
+///
+/// For ML-DSA: R = 2^32, returns y < 2q
+/// For ML-KEM: R = 2^16, returns y in range (-q, q)
+fn montgomeryReduce(
+    comptime InT: type,
+    comptime OutT: type,
+    comptime q: comptime_int,
+    comptime qInv: comptime_int,
+    comptime r_bits: comptime_int,
+    x: InT,
+) OutT {
+    const mask = (@as(InT, 1) << r_bits) - 1;
+    const m_full = (x *% qInv) & mask;
+    const m: OutT = @truncate(m_full);
+
+    const yR = x -% @as(InT, m) * @as(InT, q);
+    const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).int.bits), @bitCast(yR)) >> r_bits;
+    return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).int.bits), @truncate(y_shifted)));
+}
+
+/// Uniform sampling using SHAKE-128 with rejection sampling.
+/// Samples polynomial coefficients uniformly from [0, q) using rejection sampling.
+///
+/// Parameters:
+/// - PolyType: The polynomial type to return
+/// - q: Modulus
+/// - bits_per_coef: Number of bits per coefficient (12 or 23)
+/// - n: Number of coefficients
+/// - seed: Random seed
+/// - domain_sep: Domain separation bytes (appended to seed)
+fn sampleUniformRejection(
+    comptime PolyType: type,
+    comptime q: comptime_int,
+    comptime bits_per_coef: comptime_int,
+    comptime n: comptime_int,
+    seed: []const u8,
+    domain_sep: []const u8,
+) PolyType {
+    var h = sha3.Shake128.init(.{});
+    h.update(seed);
+    h.update(domain_sep);
+
+    const buf_len = sha3.Shake128.block_length; // 168 bytes
+    var buf: [buf_len]u8 = undefined;
+
+    var ret: PolyType = undefined;
+    var coef_idx: usize = 0;
+
+    if (bits_per_coef == 12) {
+        // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each)
+        outer: while (true) {
+            h.squeeze(&buf);
+
+            var j: usize = 0;
+            while (j < buf_len) : (j += 3) {
+                const b0 = @as(u16, buf[j]);
+                const b1 = @as(u16, buf[j + 1]);
+                const b2 = @as(u16, buf[j + 2]);
+
+                const ts: [2]u16 = .{
+                    b0 | ((b1 & 0xf) << 8),
+                    (b1 >> 4) | (b2 << 4),
+                };
+
+                inline for (ts) |t| {
+                    if (t < q) {
+                        ret.cs[coef_idx] = @intCast(t);
+                        coef_idx += 1;
+                        if (coef_idx == n) break :outer;
+                    }
+                }
+            }
+        }
+    } else if (bits_per_coef == 23) {
+        // ML-DSA path: 1 coefficient per 3 bytes (23 bits)
+        while (coef_idx < n) {
+            h.squeeze(&buf);
+
+            var j: usize = 0;
+            while (j < buf_len and coef_idx < n) : (j += 3) {
+                const t = (@as(u32, buf[j]) |
+                    (@as(u32, buf[j + 1]) << 8) |
+                    (@as(u32, buf[j + 2]) << 16)) & 0x7fffff;
+
+                if (t < q) {
+                    ret.cs[coef_idx] = @intCast(t);
+                    coef_idx += 1;
+                }
+            }
+        }
+    } else {
+        @compileError("bits_per_coef must be 12 or 23");
+    }
+
+    return ret;
+}
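
A minimal sketch (not from this diff) of how the two new shared helpers could be exercised from tests inside ml_kem.zig. The test names, the fixed 0x42 seed, and the qInv constant 62209 (q⁻¹ mod 2¹⁶) are illustrative assumptions, not values taken from the commit.

test "sampleUniformRejection: ML-KEM parameters (sketch)" {
    const seed = [_]u8{0x42} ** 32;
    const domain_sep: [2]u8 = .{ 0, 0 };
    // Same parameters Poly.uniform passes above: modulus Q, 12 bits per coefficient, N coefficients.
    const p = sampleUniformRejection(Poly, Q, 12, N, &seed, &domain_sep);
    for (p.cs) |c| {
        try testing.expect(c >= 0 and c < Q);
    }
}

test "montgomeryReduce: ML-KEM parameters (sketch)" {
    // Assumed instantiation: R = 2^16, q = 3329, qInv = q^-1 mod 2^16 = 62209.
    // The input 5q is a multiple of q, so the reduced value must be 0.
    const y = montgomeryReduce(i32, i16, 3329, 62209, 16, 3329 * 5);
    try testing.expectEqual(@as(i16, 0), y);
}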