Commit 4ea4728084
Changed files (1)
lib
std
crypto
lib/std/crypto/ml_kem.zig
@@ -105,19 +105,20 @@ const crypto = std.crypto;
const errors = std.crypto.errors;
const math = std.math;
const mem = std.mem;
-const RndGen = std.Random.DefaultPrng;
const sha3 = crypto.hash.sha3;
-// Q is the parameter q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1.
+const RndGen = std.Random.DefaultPrng;
+
+// Q is the modulus q ≡ 3329 = 2¹¹ + 2¹⁰ + 2⁸ + 1
const Q: i16 = 3329;
-// Montgomery R
+// Montgomery R = 2^16 mod Q (for Montgomery multiplication)
const R: i32 = 1 << 16;
-// Parameter n, degree of polynomials.
+// N is the degree of polynomials (polynomial ring dimension)
const N: usize = 256;
-// Size of "small" vectors used in encryption blinds.
+// eta2 is the size of "small" vectors used in encryption blinds
const eta2: u8 = 2;
const Params = struct {
@@ -215,7 +216,7 @@ fn Kyber(comptime p: Params) type {
pub const ciphertext_length = Poly.compressedSize(p.du) * p.k + Poly.compressedSize(p.dv);
const Self = @This();
- const V = Vec(p.k);
+ const V = PolyVec(p.k);
const M = Mat(p.k);
/// Length (in bytes) of a shared secret.
@@ -241,7 +242,7 @@ fn Kyber(comptime p: Params) type {
hpk: [h_length]u8, // H(pk)
/// Size of a serialized representation of the key, in bytes.
- pub const bytes_length = InnerPk.bytes_length;
+ pub const encoded_length = InnerPk.encoded_length;
/// Generates a shared secret, and encapsulates it for the public key.
/// If `seed` is `null`, a random seed is used. This is recommended.
@@ -289,14 +290,14 @@ fn Kyber(comptime p: Params) type {
}
/// Serializes the key into a byte array.
- pub fn toBytes(pk: PublicKey) [bytes_length]u8 {
+ pub fn toBytes(pk: PublicKey) [encoded_length]u8 {
return pk.pk.toBytes();
}
/// Deserializes the key from a byte array.
- pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!PublicKey {
+ pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!PublicKey {
var ret: PublicKey = undefined;
- ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.bytes_length]);
+ ret.pk = try InnerPk.fromBytes(buf[0..InnerPk.encoded_length]);
sha3.Sha3_256.hash(buf, &ret.hpk, .{});
return ret;
}
@@ -310,8 +311,8 @@ fn Kyber(comptime p: Params) type {
z: [shared_length]u8,
/// Size of a serialized representation of the key, in bytes.
- pub const bytes_length: usize =
- InnerSk.bytes_length + InnerPk.bytes_length + h_length + shared_length;
+ pub const encoded_length: usize =
+ InnerSk.encoded_length + InnerPk.encoded_length + h_length + shared_length;
/// Decapsulates the shared secret within ct using the private key.
pub fn decaps(sk: SecretKey, ct: *const [ciphertext_length]u8) ![shared_length]u8 {
@@ -346,18 +347,18 @@ fn Kyber(comptime p: Params) type {
}
/// Serializes the key into a byte array.
- pub fn toBytes(sk: SecretKey) [bytes_length]u8 {
+ pub fn toBytes(sk: SecretKey) [encoded_length]u8 {
return sk.sk.toBytes() ++ sk.pk.toBytes() ++ sk.hpk ++ sk.z;
}
/// Deserializes the key from a byte array.
- pub fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!SecretKey {
+ pub fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!SecretKey {
var ret: SecretKey = undefined;
comptime var s: usize = 0;
- ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.bytes_length]);
- s += InnerSk.bytes_length;
- ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.bytes_length]);
- s += InnerPk.bytes_length;
+ ret.sk = InnerSk.fromBytes(buf[s .. s + InnerSk.encoded_length]);
+ s += InnerSk.encoded_length;
+ ret.pk = try InnerPk.fromBytes(buf[s .. s + InnerPk.encoded_length]);
+ s += InnerPk.encoded_length;
ret.hpk = buf[s..][0..h_length].*;
s += h_length;
ret.z = buf[s..][0..shared_length].*;
@@ -418,7 +419,7 @@ fn Kyber(comptime p: Params) type {
// Cached values
aT: M,
- const bytes_length = V.bytes_length + 32;
+ const encoded_length = V.encoded_length + 32;
fn encrypt(
pk: InnerPk,
@@ -436,7 +437,7 @@ fn Kyber(comptime p: Params) type {
// Note that coefficients of r are bounded by q and those of Aᵀ
// are bounded by 4.5q and so their product is bounded by 2¹⁵q
// as required for multiplication.
- u.ps[i] = pk.aT.vs[i].dotHat(rh);
+ u.ps[i] = pk.aT.rows[i].dotHat(rh);
}
// Aᵀ and r were not in Montgomery form, so the Montgomery
@@ -451,14 +452,14 @@ fn Kyber(comptime p: Params) type {
return u.compress(p.du) ++ v.compress(p.dv);
}
- fn toBytes(pk: InnerPk) [bytes_length]u8 {
+ fn toBytes(pk: InnerPk) [encoded_length]u8 {
return pk.th.toBytes() ++ pk.rho;
}
- fn fromBytes(buf: *const [bytes_length]u8) errors.NonCanonicalError!InnerPk {
+ fn fromBytes(buf: *const [encoded_length]u8) errors.NonCanonicalError!InnerPk {
var ret: InnerPk = undefined;
- const th_bytes = buf[0..V.bytes_length];
+ const th_bytes = buf[0..V.encoded_length];
ret.th = V.fromBytes(th_bytes).normalize();
if (p.ml_kem) {
@@ -468,7 +469,7 @@ fn Kyber(comptime p: Params) type {
}
}
- ret.rho = buf[V.bytes_length..bytes_length].*;
+ ret.rho = buf[V.encoded_length..encoded_length].*;
ret.aT = M.uniform(ret.rho, true);
return ret;
}
@@ -477,7 +478,7 @@ fn Kyber(comptime p: Params) type {
// Private key of the inner PKE
const InnerSk = struct {
sh: V, // NTT(s), normalized
- const bytes_length = V.bytes_length;
+ const encoded_length = V.encoded_length;
fn decrypt(sk: InnerSk, ct: *const [ciphertext_length]u8) [inner_plaintext_length]u8 {
const u = V.decompress(p.du, ct[0..comptime V.compressedSize(p.du)]);
@@ -491,11 +492,11 @@ fn Kyber(comptime p: Params) type {
.normalize().compress(1);
}
- fn toBytes(sk: InnerSk) [bytes_length]u8 {
+ fn toBytes(sk: InnerSk) [encoded_length]u8 {
return sk.sh.toBytes();
}
- fn fromBytes(buf: *const [bytes_length]u8) InnerSk {
+ fn fromBytes(buf: *const [encoded_length]u8) InnerSk {
var ret: InnerSk = undefined;
ret.sh = V.fromBytes(buf).normalize();
return ret;
@@ -516,7 +517,7 @@ fn Kyber(comptime p: Params) type {
// Sample secret vector s.
sk.sh = V.noise(p.eta1, 0, sigma).ntt().normalize();
- const eh = Vec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
+ const eh = PolyVec(p.k).noise(p.eta1, p.k, sigma).ntt(); // sample blind e.
var th: V = undefined;
// Next, we compute t = A s + e.
@@ -528,7 +529,7 @@ fn Kyber(comptime p: Params) type {
// multiplications in the inner product added a factor R⁻¹ which
// we'll cancel out with toMont(). This will also ensure the
// coefficients of th are bounded in absolute value by q.
- th.ps[i] = pk.aT.vs[i].dotHat(sk.sh).toMont();
+ th.ps[i] = pk.aT.rows[i].dotHat(sk.sh).toMont();
}
pk.th = th.add(eh).normalize(); // bounded by 8q
@@ -565,7 +566,6 @@ const zetas = computeZetas();
// not enough, the other coefficient is reduced as well.
//
// This is actually optimal, as proven in https://eprint.iacr.org/2020/1377.pdf
-// TODO generate comptime?
const inv_ntt_reductions = [_]i16{
-1, // after layer 1
-1, // after layer 2
@@ -634,31 +634,8 @@ test "invNTTReductions bounds" {
}
}
-// Extended euclidean algorithm.
-//
-// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
-// modular inverse.
-fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
- if (a == 0) {
- return .{ .gcd = b, .x = 0, .y = 1 };
- }
- const r = eea(@rem(b, a), a);
- return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
-}
-
-fn EeaResult(comptime T: type) type {
- return struct { gcd: T, x: T, y: T };
-}
-
-// Returns least common multiple of a and b.
-fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
- const r = eea(a, b);
- return a * b / r.gcd;
-}
-
-// Invert modulo p.
fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
- const r = eea(a, p);
+ const r = extendedEuclidean(@TypeOf(a), a, p);
assert(r.gcd == 1);
return r.x;
}
@@ -788,31 +765,12 @@ test "Test csubq" {
}
}
-// Compute a^s mod p.
-fn mpow(a: anytype, s: @TypeOf(a), p: @TypeOf(a)) @TypeOf(a) {
- var ret: @TypeOf(a) = 1;
- var s2 = s;
- var a2 = a;
-
- while (true) {
- if (s2 & 1 == 1) {
- ret = @mod(ret * a2, p);
- }
- s2 >>= 1;
- if (s2 == 0) {
- break;
- }
- a2 = @mod(a2 * a2, p);
- }
- return ret;
-}
-
// Computes zetas table used by ntt and invNTT.
fn computeZetas() [128]i16 {
@setEvalBranchQuota(10000);
var ret: [128]i16 = undefined;
for (&ret, 0..) |*r, i| {
- const t = @as(i16, @intCast(mpow(@as(i32, zeta), @bitReverse(@as(u7, @intCast(i))), Q)));
+ const t = @as(i16, @intCast(modularPow(i32, zeta, @bitReverse(@as(u7, @intCast(i))), Q)));
r.* = csubq(feBarrettReduce(feToMont(t)));
}
return ret;
@@ -828,9 +786,10 @@ fn computeZetas() [128]i16 {
const Poly = struct {
cs: [N]i16,
- const bytes_length = N / 2 * 3;
+ const encoded_length = N / 2 * 3;
const zero: Poly = .{ .cs = .{0} ** N };
+ // Add two polynomials (coefficients not normalized)
fn add(a: Poly, b: Poly) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
@@ -839,6 +798,7 @@ const Poly = struct {
return ret;
}
+ // Subtract two polynomials (coefficients not normalized)
fn sub(a: Poly, b: Poly) Poly {
var ret: Poly = undefined;
for (0..N) |i| {
@@ -847,25 +807,6 @@ const Poly = struct {
return ret;
}
- // For testing, generates a random polynomial with for each
- // coefficient |x| ≤ q.
- fn randAbsLeqQ(rnd: anytype) Poly {
- var ret: Poly = undefined;
- for (0..N) |i| {
- ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
- }
- return ret;
- }
-
- // For testing, generates a random normalized polynomial.
- fn randNormalized(rnd: anytype) Poly {
- var ret: Poly = undefined;
- for (0..N) |i| {
- ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
- }
- return ret;
- }
-
// Executes a forward "NTT" on p.
//
// Assumes the coefficients are in absolute value ≤q. The resulting
@@ -1054,7 +995,7 @@ const Poly = struct {
var in_off: usize = 0;
var out_off: usize = 0;
- const batch_size: usize = comptime lcm(@as(i16, d), 8);
+ const batch_size: usize = comptime math.lcm(d, 8);
const in_batch_size: usize = comptime batch_size / d;
const out_batch_size: usize = comptime batch_size / 8;
@@ -1118,7 +1059,7 @@ const Poly = struct {
var in_off: usize = 0;
var out_off: usize = 0;
- const batch_size: usize = comptime lcm(@as(i16, d), 8);
+ const batch_size: usize = comptime math.lcm(d, 8);
const in_batch_size: usize = comptime batch_size / 8;
const out_batch_size: usize = comptime batch_size / d;
@@ -1275,53 +1216,23 @@ const Poly = struct {
return ret;
}
- // Sample p uniformly from the given seed and x and y coordinates.
fn uniform(seed: [32]u8, x: u8, y: u8) Poly {
- var h = sha3.Shake128.init(.{});
- const suffix: [2]u8 = .{ x, y };
- h.update(&seed);
- h.update(&suffix);
-
- const buf_len = sha3.Shake128.block_length; // rate SHAKE-128
- var buf: [buf_len]u8 = undefined;
-
- var ret: Poly = undefined;
- var i: usize = 0; // index into ret.cs
- outer: while (true) {
- h.squeeze(&buf);
-
- var j: usize = 0; // index into buf
- while (j < buf_len) : (j += 3) {
- const b0 = @as(u16, buf[j]);
- const b1 = @as(u16, buf[j + 1]);
- const b2 = @as(u16, buf[j + 2]);
-
- const ts: [2]u16 = .{
- b0 | ((b1 & 0xf) << 8),
- (b1 >> 4) | (b2 << 4),
- };
-
- inline for (ts) |t| {
- if (t < Q) {
- ret.cs[i] = @as(i16, @intCast(t));
- i += 1;
-
- if (i == N) {
- break :outer;
- }
- }
- }
- }
- }
-
- return ret;
+ const domain_sep: [2]u8 = .{ x, y };
+ return sampleUniformRejection(
+ Poly,
+ Q,
+ 12,
+ N,
+ &seed,
+ &domain_sep,
+ );
}
// Packs p.
//
// Assumes p is normalized (and not just Barrett reduced).
- fn toBytes(p: Poly) [bytes_length]u8 {
- var ret: [bytes_length]u8 = undefined;
+ fn toBytes(p: Poly) [encoded_length]u8 {
+ var ret: [encoded_length]u8 = undefined;
for (0..comptime N / 2) |i| {
const t0 = @as(u16, @intCast(p.cs[2 * i]));
const t1 = @as(u16, @intCast(p.cs[2 * i + 1]));
@@ -1335,7 +1246,7 @@ const Poly = struct {
// Unpacks a Poly from buf.
//
// p will not be normalized; instead 0 ≤ p[i] < 4096.
- fn fromBytes(buf: *const [bytes_length]u8) Poly {
+ fn fromBytes(buf: *const [encoded_length]u8) Poly {
var ret: Poly = undefined;
for (0..comptime N / 2) |i| {
const b0 = @as(i16, buf[3 * i]);
@@ -1348,71 +1259,65 @@ const Poly = struct {
}
};
-// A vector of K polynomials.
-fn Vec(comptime K: u8) type {
+// A vector of k polynomials.
+fn PolyVec(comptime k: u8) type {
return struct {
- ps: [K]Poly,
+ ps: [k]Poly,
const Self = @This();
- const bytes_length = K * Poly.bytes_length;
+ const encoded_length = k * Poly.encoded_length;
fn compressedSize(comptime d: u8) usize {
- return Poly.compressedSize(d) * K;
+ return Poly.compressedSize(d) * k;
}
- fn ntt(a: Self) Self {
+ /// Apply unary operation to each polynomial
+ fn map(v: Self, comptime op: fn (Poly) Poly) Self {
var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].ntt();
+ inline for (0..k) |i| {
+ ret.ps[i] = op(v.ps[i]);
}
return ret;
}
- fn invNTT(a: Self) Self {
+ /// Apply binary operation pairwise
+ fn mapBinary(a: Self, b: Self, comptime op: fn (Poly, Poly) Poly) Self {
var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].invNTT();
+ inline for (0..k) |i| {
+ ret.ps[i] = op(a.ps[i], b.ps[i]);
}
return ret;
}
- fn normalize(a: Self) Self {
- var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].normalize();
- }
- return ret;
+ fn ntt(v: Self) Self {
+ return map(v, Poly.ntt);
}
- fn barrettReduce(a: Self) Self {
- var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].barrettReduce();
- }
- return ret;
+ fn invNTT(v: Self) Self {
+ return map(v, Poly.invNTT);
+ }
+
+ fn normalize(v: Self) Self {
+ return map(v, Poly.normalize);
+ }
+
+ fn barrettReduce(v: Self) Self {
+ return map(v, Poly.barrettReduce);
}
fn add(a: Self, b: Self) Self {
- var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].add(b.ps[i]);
- }
- return ret;
+ return mapBinary(a, b, Poly.add);
}
fn sub(a: Self, b: Self) Self {
- var ret: Self = undefined;
- for (0..K) |i| {
- ret.ps[i] = a.ps[i].sub(b.ps[i]);
- }
- return ret;
+ return mapBinary(a, b, Poly.sub);
}
// Samples v[i] from centered binomial distribution with the given η,
// seed and nonce+i.
fn noise(comptime eta: u8, nonce: u8, seed: *const [32]u8) Self {
var ret: Self = undefined;
- for (0..K) |i| {
+ for (0..k) |i| {
ret.ps[i] = Poly.noise(eta, nonce + @as(u8, @intCast(i)), seed);
}
return ret;
@@ -1428,7 +1333,7 @@ fn Vec(comptime K: u8) type {
// of the Montgomery factor.
fn dotHat(a: Self, b: Self) Poly {
var ret: Poly = Poly.zero;
- for (0..K) |i| {
+ for (0..k) |i| {
ret = ret.add(a.ps[i].mulHat(b.ps[i]));
}
return ret;
@@ -1437,7 +1342,7 @@ fn Vec(comptime K: u8) type {
fn compress(v: Self, comptime d: u8) [compressedSize(d)]u8 {
const cs = comptime Poly.compressedSize(d);
var ret: [compressedSize(d)]u8 = undefined;
- inline for (0..K) |i| {
+ inline for (0..k) |i| {
ret[i * cs .. (i + 1) * cs].* = v.ps[i].compress(d);
}
return ret;
@@ -1446,27 +1351,27 @@ fn Vec(comptime K: u8) type {
fn decompress(comptime d: u8, buf: *const [compressedSize(d)]u8) Self {
const cs = comptime Poly.compressedSize(d);
var ret: Self = undefined;
- inline for (0..K) |i| {
+ inline for (0..k) |i| {
ret.ps[i] = Poly.decompress(d, buf[i * cs .. (i + 1) * cs]);
}
return ret;
}
/// Serializes the key into a byte array.
- fn toBytes(v: Self) [bytes_length]u8 {
- var ret: [bytes_length]u8 = undefined;
- inline for (0..K) |i| {
- ret[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length].* = v.ps[i].toBytes();
+ fn toBytes(v: Self) [encoded_length]u8 {
+ var ret: [encoded_length]u8 = undefined;
+ inline for (0..k) |i| {
+ ret[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length].* = v.ps[i].toBytes();
}
return ret;
}
/// Deserializes the key from a byte array.
- fn fromBytes(buf: *const [bytes_length]u8) Self {
+ fn fromBytes(buf: *const [encoded_length]u8) Self {
var ret: Self = undefined;
- inline for (0..K) |i| {
+ inline for (0..k) |i| {
ret.ps[i] = Poly.fromBytes(
- buf[i * Poly.bytes_length .. (i + 1) * Poly.bytes_length],
+ buf[i * Poly.encoded_length .. (i + 1) * Poly.encoded_length],
);
}
return ret;
@@ -1474,19 +1379,19 @@ fn Vec(comptime K: u8) type {
};
}
-// A matrix of K vectors
-fn Mat(comptime K: u8) type {
+// A matrix of k vectors
+fn Mat(comptime k: u8) type {
return struct {
const Self = @This();
- vs: [K]Vec(K),
+ rows: [k]PolyVec(k),
fn uniform(seed: [32]u8, comptime transposed: bool) Self {
var ret: Self = undefined;
var i: u8 = 0;
- while (i < K) : (i += 1) {
+ while (i < k) : (i += 1) {
var j: u8 = 0;
- while (j < K) : (j += 1) {
- ret.vs[i].ps[j] = Poly.uniform(
+ while (j < k) : (j += 1) {
+ ret.rows[i].ps[j] = Poly.uniform(
seed,
if (transposed) i else j,
if (transposed) j else i,
@@ -1499,9 +1404,9 @@ fn Mat(comptime K: u8) type {
// Returns transpose of A
fn transpose(m: Self) Self {
var ret: Self = undefined;
- for (0..K) |i| {
- for (0..K) |j| {
- ret.vs[i].ps[j] = m.vs[j].ps[i];
+ for (0..k) |i| {
+ for (0..k) |j| {
+ ret.rows[i].ps[j] = m.rows[j].ps[i];
}
}
return ret;
@@ -1522,12 +1427,30 @@ fn cmov(comptime len: usize, dst: *[len]u8, src: [len]u8, b: u1) void {
}
}
+// Test helper: generates a random polynomial with each coefficient |x| ≤ q
+fn randPolyAbsLeqQ(rnd: anytype) Poly {
+ var ret: Poly = undefined;
+ for (0..N) |i| {
+ ret.cs[i] = rnd.random().intRangeAtMost(i16, -Q, Q);
+ }
+ return ret;
+}
+
+// Test helper: generates a random normalized polynomial
+fn randPolyNormalized(rnd: anytype) Poly {
+ var ret: Poly = undefined;
+ for (0..N) |i| {
+ ret.cs[i] = rnd.random().intRangeLessThan(i16, 0, Q);
+ }
+ return ret;
+}
+
test "MulHat" {
var rnd = RndGen.init(0);
for (0..100) |_| {
- const a = Poly.randAbsLeqQ(&rnd);
- const b = Poly.randAbsLeqQ(&rnd);
+ const a = randPolyAbsLeqQ(&rnd);
+ const b = randPolyAbsLeqQ(&rnd);
const p2 = a.ntt().mulHat(b.ntt()).barrettReduce().invNTT().normalize();
var p: Poly = undefined;
@@ -1557,7 +1480,7 @@ test "NTT" {
var rnd = RndGen.init(0);
for (0..1000) |_| {
- var p = Poly.randAbsLeqQ(&rnd);
+ var p = randPolyAbsLeqQ(&rnd);
const q = p.toMont().normalize();
p = p.ntt();
@@ -1580,7 +1503,7 @@ test "Compression" {
var rnd = RndGen.init(0);
inline for (.{ 1, 4, 5, 10, 11 }) |d| {
for (0..1000) |_| {
- const p = Poly.randNormalized(&rnd);
+ const p = randPolyNormalized(&rnd);
const pp = p.compress(d);
const pq = Poly.decompress(d, &pp).compress(d);
try testing.expectEqual(pp, pq);
@@ -1671,7 +1594,7 @@ test "Polynomial packing" {
var rnd = RndGen.init(0);
for (0..1000) |_| {
- const p = Poly.randNormalized(&rnd);
+ const p = randPolyNormalized(&rnd);
try testing.expectEqual(Poly.fromBytes(&p.toBytes()), p);
}
}
@@ -1839,3 +1762,222 @@ const NistDRBG = struct {
return ret;
}
};
+
+/// Extended Euclidian Algorithm
+/// Only meant to be used on comptime values; correctness matters, performance doesn't.
+fn extendedEuclidean(comptime T: type, comptime a_: T, comptime b_: T) struct { gcd: T, x: T, y: T } {
+ var a = a_;
+ var b = b_;
+ var x0: T = 1;
+ var x1: T = 0;
+ var y0: T = 0;
+ var y1: T = 1;
+
+ while (b != 0) {
+ const q = @divTrunc(a, b);
+ const temp_a = a;
+ a = b;
+ b = temp_a - q * b;
+
+ const temp_x = x0;
+ x0 = x1;
+ x1 = temp_x - q * x1;
+
+ const temp_y = y0;
+ y0 = y1;
+ y1 = temp_y - q * y1;
+ }
+
+ return .{ .gcd = a, .x = x0, .y = y0 };
+}
+
+/// Modular inversion: computes a^(-1) mod p
+/// Requires gcd(a,p) = 1. The result is normalized to the range [0, p).
+fn modularInverse(comptime T: type, comptime a: T, comptime p: T) T {
+ // Use a signed type for EEA computation
+ const type_info = @typeInfo(T);
+ const SignedT = if (type_info == .int and type_info.int.signedness == .unsigned)
+ std.meta.Int(.signed, type_info.int.bits)
+ else
+ T;
+
+ const a_signed = @as(SignedT, @intCast(a));
+ const p_signed = @as(SignedT, @intCast(p));
+
+ const r = extendedEuclidean(SignedT, a_signed, p_signed);
+ assert(r.gcd == 1);
+
+ // Normalize result to [0, p)
+ var result = r.x;
+ while (result < 0) {
+ result += p_signed;
+ }
+
+ return @intCast(result);
+}
+
+/// Modular exponentiation: computes a^s mod p using square-and-multiply algorithm.
+fn modularPow(comptime T: type, comptime a: T, s: T, comptime p: T) T {
+ const type_info = @typeInfo(T);
+ const bits = type_info.int.bits;
+ const WideT = std.meta.Int(.unsigned, bits * 2);
+
+ var ret: T = 1;
+ var base: T = a;
+ var exp = s;
+
+ while (exp > 0) {
+ if (exp & 1 == 1) {
+ ret = @intCast((@as(WideT, ret) * @as(WideT, base)) % p);
+ }
+ base = @intCast((@as(WideT, base) * @as(WideT, base)) % p);
+ exp >>= 1;
+ }
+
+ return ret;
+}
+
+/// Creates an all-ones or all-zeros mask from a single bit value.
+/// Returns all 1s (0xFF...FF) if bit == 1, all 0s if bit == 0.
+fn bitMask(comptime T: type, bit: T) T {
+ const type_info = @typeInfo(T);
+ if (type_info != .int or type_info.int.signedness != .unsigned) {
+ @compileError("bitMask requires an unsigned integer type");
+ }
+ return -%bit;
+}
+
+/// Creates a mask from the sign bit of a signed integer.
+/// Returns all 1s (0xFF...FF) if x < 0, all 0s if x >= 0.
+fn signMask(comptime T: type, x: T) std.meta.Int(.unsigned, @typeInfo(T).int.bits) {
+ const type_info = @typeInfo(T);
+ if (type_info != .int) {
+ @compileError("signMask requires an integer type");
+ }
+
+ const bits = type_info.int.bits;
+ const SignedT = std.meta.Int(.signed, bits);
+
+ // Convert to signed if needed, arithmetic right shift to propagate sign bit
+ const x_signed: SignedT = if (type_info.int.signedness == .signed) x else @bitCast(x);
+ const shifted = x_signed >> (bits - 1);
+ return @bitCast(shifted);
+}
+
+test "bitMask and signMask helpers" {
+ try testing.expectEqual(@as(u32, 0x00000000), bitMask(u32, 0));
+ try testing.expectEqual(@as(u32, 0xFFFFFFFF), bitMask(u32, 1));
+ try testing.expectEqual(@as(u8, 0x00), bitMask(u8, 0));
+ try testing.expectEqual(@as(u8, 0xFF), bitMask(u8, 1));
+ try testing.expectEqual(@as(u64, 0x0000000000000000), bitMask(u64, 0));
+ try testing.expectEqual(@as(u64, 0xFFFFFFFFFFFFFFFF), bitMask(u64, 1));
+
+ try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -1));
+ try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(i32, -100));
+ try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 0));
+ try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 1));
+ try testing.expectEqual(@as(u32, 0x00000000), signMask(i32, 100));
+
+ try testing.expectEqual(@as(u32, 0xFFFFFFFF), signMask(u32, 0x80000000)); // MSB set
+ try testing.expectEqual(@as(u32, 0x00000000), signMask(u32, 0x7FFFFFFF)); // MSB clear
+}
+
+/// Montgomery reduction: for input x, returns y where y ≡ x*R^(-1) (mod q).
+/// This is a generic implementation parameterized by the modulus q, its inverse qInv,
+/// the Montgomery constant R, and the result bound.
+///
+/// For ML-DSA: R = 2^32, returns y < 2q
+/// For ML-KEM: R = 2^16, returns y in range (-q, q)
+fn montgomeryReduce(
+ comptime InT: type,
+ comptime OutT: type,
+ comptime q: comptime_int,
+ comptime qInv: comptime_int,
+ comptime r_bits: comptime_int,
+ x: InT,
+) OutT {
+ const mask = (@as(InT, 1) << r_bits) - 1;
+ const m_full = (x *% qInv) & mask;
+ const m: OutT = @truncate(m_full);
+
+ const yR = x -% @as(InT, m) * @as(InT, q);
+ const y_shifted = @as(std.meta.Int(.unsigned, @typeInfo(InT).Int.bits), @bitCast(yR)) >> r_bits;
+ return @bitCast(@as(std.meta.Int(.unsigned, @typeInfo(OutT).Int.bits), @truncate(y_shifted)));
+}
+
+/// Uniform sampling using SHAKE-128 with rejection sampling.
+/// Samples polynomial coefficients uniformly from [0, q) using rejection sampling.
+///
+/// Parameters:
+/// - PolyType: The polynomial type to return
+/// - q: Modulus
+/// - bits_per_coef: Number of bits per coefficient (12 or 23)
+/// - n: Number of coefficients
+/// - seed: Random seed
+/// - domain_sep: Domain separation bytes (appended to seed)
+fn sampleUniformRejection(
+ comptime PolyType: type,
+ comptime q: comptime_int,
+ comptime bits_per_coef: comptime_int,
+ comptime n: comptime_int,
+ seed: []const u8,
+ domain_sep: []const u8,
+) PolyType {
+ var h = sha3.Shake128.init(.{});
+ h.update(seed);
+ h.update(domain_sep);
+
+ const buf_len = sha3.Shake128.block_length; // 168 bytes
+ var buf: [buf_len]u8 = undefined;
+
+ var ret: PolyType = undefined;
+ var coef_idx: usize = 0;
+
+ if (bits_per_coef == 12) {
+ // ML-KEM path: pack 2 coefficients per 3 bytes (12 bits each)
+ outer: while (true) {
+ h.squeeze(&buf);
+
+ var j: usize = 0;
+ while (j < buf_len) : (j += 3) {
+ const b0 = @as(u16, buf[j]);
+ const b1 = @as(u16, buf[j + 1]);
+ const b2 = @as(u16, buf[j + 2]);
+
+ const ts: [2]u16 = .{
+ b0 | ((b1 & 0xf) << 8),
+ (b1 >> 4) | (b2 << 4),
+ };
+
+ inline for (ts) |t| {
+ if (t < q) {
+ ret.cs[coef_idx] = @intCast(t);
+ coef_idx += 1;
+ if (coef_idx == n) break :outer;
+ }
+ }
+ }
+ }
+ } else if (bits_per_coef == 23) {
+ // ML-DSA path: 1 coefficient per 3 bytes (23 bits)
+ while (coef_idx < n) {
+ h.squeeze(&buf);
+
+ var j: usize = 0;
+ while (j < buf_len and coef_idx < n) : (j += 3) {
+ const t = (@as(u32, buf[j]) |
+ (@as(u32, buf[j + 1]) << 8) |
+ (@as(u32, buf[j + 2]) << 16)) & 0x7fffff;
+
+ if (t < q) {
+ ret.cs[coef_idx] = @intCast(t);
+ coef_idx += 1;
+ }
+ }
+ }
+ } else {
+ @compileError("bits_per_coef must be 12 or 23");
+ }
+
+ return ret;
+}