master
  1//! Allocation-free, (best-effort) constant-time, finite field arithmetic for large integers.
  2//!
  3//! Unlike `std.math.big`, these integers have a fixed maximum length and are only designed to be used for modular arithmetic.
  4//! Arithmetic operations are meant to run in constant-time for a given modulus, making them suitable for cryptography.
  5//!
//! Parts of this code were ported from the BSD-licensed crypto/internal/bigmod/nat.go file in the Go language, itself inspired by BearSSL.
  7
  8const std = @import("std");
  9const builtin = @import("builtin");
 10const crypto = std.crypto;
 11const math = std.math;
 12const mem = std.mem;
 13const meta = std.meta;
 14const testing = std.testing;
 15const assert = std.debug.assert;
 16const Endian = std.builtin.Endian;
 17
 18// A Limb is a single digit in a big integer.
 19const Limb = usize;
 20
 21// The number of reserved bits in a Limb.
 22const carry_bits = 1;
 23
 24// The number of active bits in a Limb.
 25const t_bits: usize = @bitSizeOf(Limb) - carry_bits;
 26
 27// A TLimb is a Limb that is truncated to t_bits.
 28const TLimb = meta.Int(.unsigned, t_bits);
 29
 30const native_endian = builtin.target.cpu.arch.endian();
 31
 32// A WideLimb is a Limb that is twice as wide as a normal Limb.
 33const WideLimb = struct {
 34    hi: Limb,
 35    lo: Limb,
 36};
 37
 38/// Value is too large for the destination.
 39pub const OverflowError = error{Overflow};
 40
 41/// Invalid modulus. Modulus must be odd.
 42pub const InvalidModulusError = error{ EvenModulus, ModulusTooSmall };
 43
 44/// Exponentiation with a null exponent.
 45/// Exponentiation in cryptographic protocols is almost always a sign of a bug which can lead to trivial attacks.
 46/// Therefore, this module returns an error when a null exponent is encountered, encouraging applications to handle this case explicitly.
 47pub const NullExponentError = error{NullExponent};
 48
 49/// Invalid field element for the given modulus.
 50pub const FieldElementError = error{NonCanonical};
 51
 52/// Invalid representation (Montgomery vs non-Montgomery domain.)
 53pub const RepresentationError = error{UnexpectedRepresentation};
 54
 55/// The set of all possible errors `std.crypto.ff` functions can return.
 56pub const Error = OverflowError || InvalidModulusError || NullExponentError || FieldElementError || RepresentationError;
 57
 58/// An unsigned big integer with a fixed maximum size (`max_bits`), suitable for cryptographic operations.
 59/// Unless side-channels mitigations are explicitly disabled, operations are designed to be constant-time.
 60pub fn Uint(comptime max_bits: comptime_int) type {
 61    comptime assert(@bitSizeOf(Limb) % 8 == 0); // Limb size must be a multiple of 8
 62
 63    return struct {
 64        const Self = @This();
 65        const max_limbs_count = math.divCeil(usize, max_bits, t_bits) catch unreachable;
 66
 67        limbs_buffer: [max_limbs_count]Limb,
 68        /// The number of active limbs.
 69        limbs_len: usize,
 70
 71        /// Number of bytes required to serialize an integer.
 72        pub const encoded_bytes = math.divCeil(usize, max_bits, 8) catch unreachable;
 73
 74        /// Constant slice of active limbs.
 75        fn limbsConst(self: *const Self) []const Limb {
 76            return self.limbs_buffer[0..self.limbs_len];
 77        }
 78
 79        /// Mutable slice of active limbs.
 80        fn limbs(self: *Self) []Limb {
 81            return self.limbs_buffer[0..self.limbs_len];
 82        }
 83
 84        // Removes limbs whose value is zero from the active limbs.
 85        fn normalize(self: Self) Self {
 86            var res = self;
 87            if (self.limbs_len < 2) {
 88                return res;
 89            }
 90            var i = self.limbs_len - 1;
 91            while (i > 0 and res.limbsConst()[i] == 0) : (i -= 1) {}
 92            res.limbs_len = i + 1;
 93            assert(res.limbs_len <= res.limbs_buffer.len);
 94            return res;
 95        }
 96
 97        /// The zero integer.
 98        pub const zero: Self = .{
 99            .limbs_buffer = [1]Limb{0} ** max_limbs_count,
100            .limbs_len = max_limbs_count,
101        };
102
103        /// Creates a new big integer from a primitive type.
104        /// This function may not run in constant time.
105        pub fn fromPrimitive(comptime T: type, init_value: T) OverflowError!Self {
106            var x = init_value;
107            var out: Self = .{
108                .limbs_buffer = undefined,
109                .limbs_len = max_limbs_count,
110            };
111            for (&out.limbs_buffer) |*limb| {
112                limb.* = if (@bitSizeOf(T) > t_bits) @as(TLimb, @truncate(x)) else x;
113                x = math.shr(T, x, t_bits);
114            }
115            if (x != 0) {
116                return error.Overflow;
117            }
118            return out;
119        }
120
121        /// Converts a big integer to a primitive type.
122        /// This function may not run in constant time.
123        pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
124            var x: T = 0;
125            var i = self.limbs_len - 1;
126            while (true) : (i -= 1) {
127                if (@bitSizeOf(T) >= t_bits and math.shr(T, x, @bitSizeOf(T) - t_bits) != 0) {
128                    return error.Overflow;
129                }
130                x = math.shl(T, x, t_bits);
131                const v = math.cast(T, self.limbsConst()[i]) orelse return error.Overflow;
132                x |= v;
133                if (i == 0) break;
134            }
135            return x;
136        }
137
138        /// Encodes a big integer into a byte array.
139        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
140            if (bytes.len == 0) {
141                if (self.isZero()) return;
142                return error.Overflow;
143            }
144            @memset(bytes, 0);
145            var shift: usize = 0;
146            var out_i: usize = switch (endian) {
147                .big => bytes.len - 1,
148                .little => 0,
149            };
150            for (0..self.limbs_len) |i| {
151                var remaining_bits = t_bits;
152                var limb = self.limbsConst()[i];
153                while (remaining_bits >= 8) {
154                    bytes[out_i] |= math.shl(u8, @as(u8, @truncate(limb)), shift);
155                    const consumed = 8 - shift;
156                    limb >>= @as(u4, @truncate(consumed));
157                    remaining_bits -= consumed;
158                    shift = 0;
159                    switch (endian) {
160                        .big => {
161                            if (out_i == 0) {
162                                if (i != self.limbs_len - 1 or limb != 0) {
163                                    return error.Overflow;
164                                }
165                                return;
166                            }
167                            out_i -= 1;
168                        },
169                        .little => {
170                            out_i += 1;
171                            if (out_i == bytes.len) {
172                                if (i != self.limbs_len - 1 or limb != 0) {
173                                    return error.Overflow;
174                                }
175                                return;
176                            }
177                        },
178                    }
179                }
180                bytes[out_i] |= @as(u8, @truncate(limb));
181                shift = remaining_bits;
182            }
183        }
184
185        /// Creates a new big integer from a byte array.
186        pub fn fromBytes(bytes: []const u8, comptime endian: Endian) OverflowError!Self {
187            if (bytes.len == 0) return Self.zero;
188            var shift: usize = 0;
189            var out = Self.zero;
190            var out_i: usize = 0;
191            var i: usize = switch (endian) {
192                .big => bytes.len - 1,
193                .little => 0,
194            };
195            while (true) {
196                const bi = bytes[i];
197                out.limbs()[out_i] |= math.shl(Limb, bi, shift);
198                shift += 8;
199                if (shift >= t_bits) {
200                    shift -= t_bits;
201                    out.limbs()[out_i] = @as(TLimb, @truncate(out.limbs()[out_i]));
202                    const overflow = math.shr(Limb, bi, 8 - shift);
203                    out_i += 1;
204                    if (out_i >= out.limbs_len) {
205                        if (overflow != 0 or i != 0) {
206                            return error.Overflow;
207                        }
208                        break;
209                    }
210                    out.limbs()[out_i] = overflow;
211                }
212                switch (endian) {
213                    .big => {
214                        if (i == 0) break;
215                        i -= 1;
216                    },
217                    .little => {
218                        i += 1;
219                        if (i == bytes.len) break;
220                    },
221                }
222            }
223            return out;
224        }
225
226        /// Returns `true` if both integers are equal.
227        pub fn eql(x: Self, y: Self) bool {
228            return crypto.timing_safe.eql([max_limbs_count]Limb, x.limbs_buffer, y.limbs_buffer);
229        }
230
231        /// Compares two integers.
232        pub fn compare(x: Self, y: Self) math.Order {
233            return crypto.timing_safe.compare(
234                Limb,
235                x.limbsConst(),
236                y.limbsConst(),
237                .little,
238            );
239        }
240
241        /// Returns `true` if the integer is zero.
242        pub fn isZero(x: Self) bool {
243            var t: Limb = 0;
244            for (x.limbsConst()) |elem| {
245                t |= elem;
246            }
247            return ct.eql(t, 0);
248        }
249
250        /// Returns `true` if the integer is odd.
251        pub fn isOdd(x: Self) bool {
252            return @as(u1, @truncate(x.limbsConst()[0])) != 0;
253        }
254
255        /// Adds `y` to `x`, and returns `true` if the operation overflowed.
256        pub fn addWithOverflow(x: *Self, y: Self) u1 {
257            return x.conditionalAddWithOverflow(true, y);
258        }
259
260        /// Subtracts `y` from `x`, and returns `true` if the operation overflowed.
261        pub fn subWithOverflow(x: *Self, y: Self) u1 {
262            return x.conditionalSubWithOverflow(true, y);
263        }
264
265        // Replaces the limbs of `x` with the limbs of `y` if `on` is `true`.
266        fn cmov(x: *Self, on: bool, y: Self) void {
267            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
268                x_limb.* = ct.select(on, y_limb, x_limb.*);
269            }
270        }
271
272        // Adds `y` to `x` if `on` is `true`, and returns `true` if the
273        // operation overflowed.
274        fn conditionalAddWithOverflow(x: *Self, on: bool, y: Self) u1 {
275            var carry: u1 = 0;
276            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
277                const res = x_limb.* + y_limb + carry;
278                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
279                carry = @truncate(res >> t_bits);
280            }
281            return carry;
282        }
283
284        // Subtracts `y` from `x` if `on` is `true`, and returns `true` if the
285        // operation overflowed.
286        fn conditionalSubWithOverflow(x: *Self, on: bool, y: Self) u1 {
287            var borrow: u1 = 0;
288            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
289                const res = x_limb.* -% y_limb -% borrow;
290                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
291                borrow = @truncate(res >> t_bits);
292            }
293            return borrow;
294        }
295    };
296}
297
/// A field element: a `Uint(bits)` value plus a flag recording whether it is in Montgomery form.
fn Fe_(comptime bits: comptime_int) type {
    return struct {
        const Self = @This();

        const FeUint = Uint(bits);

        /// Raw value of the element.
        v: FeUint,

        /// Set when `v` holds the Montgomery representation of the element.
        montgomery: bool = false,

        /// Upper bound on the number of bytes needed to encode an element.
        pub const encoded_bytes = FeUint.encoded_bytes;

        // Number of limbs of `v` that are currently in use.
        fn limbs_count(self: Self) usize {
            return self.v.limbs_len;
        }

        // Wraps `v` as an element of the field defined by `m`: its active limbs are trimmed to the
        // size of the modulus, and it is rejected unless it is smaller than the modulus.
        fn fromUintChecked(m: Modulus(bits), v: FeUint) (OverflowError || FieldElementError)!Self {
            var fe: Self = .{ .v = v };
            try m.shrink(&fe);
            try m.rejectNonCanonical(fe);
            return fe;
        }

        /// Creates a field element from a primitive.
        /// This function may not run in constant time.
        pub fn fromPrimitive(comptime T: type, m: Modulus(bits), x: T) (OverflowError || FieldElementError)!Self {
            comptime assert(@bitSizeOf(T) <= bits); // Primitive type is larger than the modulus type.
            const v = try FeUint.fromPrimitive(T, x);
            return fromUintChecked(m, v);
        }

        /// Converts the field element to a primitive.
        /// This function may not run in constant time.
        pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
            return FeUint.toPrimitive(self.v, T);
        }

        /// Creates a field element from a byte string.
        /// The value must be smaller than the modulus `m`.
        pub fn fromBytes(m: Modulus(bits), bytes: []const u8, comptime endian: Endian) (OverflowError || FieldElementError)!Self {
            const v = try FeUint.fromBytes(bytes, endian);
            return fromUintChecked(m, v);
        }

        /// Serializes the raw value `v` of the element into `bytes`, whatever its representation.
        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
            return FeUint.toBytes(self.v, bytes, endian);
        }

        /// Returns `true` if both elements hold the same value, in constant time.
        /// The representation flag is not compared.
        pub fn eql(x: Self, y: Self) bool {
            return FeUint.eql(x.v, y.v);
        }

        /// Compares the values of two field elements in constant time.
        pub fn compare(x: Self, y: Self) math.Order {
            return FeUint.compare(x.v, y.v);
        }

        /// Returns `true` if the element is zero.
        pub fn isZero(self: Self) bool {
            return FeUint.isZero(self.v);
        }

        /// Returns `true` if the element is odd.
        pub fn isOdd(self: Self) bool {
            return FeUint.isOdd(self.v);
        }
    };
}
371
372/// A modulus, defining a finite field.
373/// All operations within the field are performed modulo this modulus, without heap allocations.
374/// `max_bits` represents the number of bits in the maximum value the modulus can be set to.
375pub fn Modulus(comptime max_bits: comptime_int) type {
376    return struct {
377        const Self = @This();
378
379        /// A field element, representing a value within the field defined by this modulus.
380        pub const Fe = Fe_(max_bits);
381
382        const FeUint = Fe.FeUint;
383
384        /// The neutral element.
385        zero: Fe,
386
387        /// The modulus value.
388        v: FeUint,
389
390        /// R^2 for the Montgomery representation.
391        rr: Fe,
392        /// Inverse of the first limb
393        m0inv: Limb,
394        /// Number of leading zero bits in the modulus.
395        leading: usize,
396
397        // Number of active limbs in the modulus.
398        fn limbs_count(self: Self) usize {
399            return self.v.limbs_len;
400        }
401
402        /// Actual size of the modulus, in bits.
403        pub fn bits(self: Self) usize {
404            return self.limbs_count() * t_bits - self.leading;
405        }
406
407        /// Returns the element `1`.
408        pub fn one(self: Self) Fe {
409            var fe = self.zero;
410            fe.v.limbs()[0] = 1;
411            return fe;
412        }
413
414        /// Creates a new modulus from a `Uint` value.
415        /// The modulus must be odd and larger than 2.
416        pub fn fromUint(v_: FeUint) InvalidModulusError!Self {
417            if (!v_.isOdd()) return error.EvenModulus;
418
419            var v = v_.normalize();
420            const hi = v.limbsConst()[v.limbs_len - 1];
421            const lo = v.limbsConst()[0];
422
423            if (v.limbs_len < 2 and lo < 3) {
424                return error.ModulusTooSmall;
425            }
426
427            const leading = @clz(hi) - carry_bits;
428
429            var y = lo;
430
431            inline for (0..comptime math.log2_int(usize, t_bits)) |_| {
432                y = y *% (2 -% lo *% y);
433            }
434            const m0inv = (@as(Limb, 1) << t_bits) - (@as(TLimb, @truncate(y)));
435
436            const zero = Fe{ .v = FeUint.zero };
437
438            var m = Self{
439                .zero = zero,
440                .v = v,
441                .leading = leading,
442                .m0inv = m0inv,
443                .rr = undefined, // will be computed right after
444            };
445            m.shrink(&m.zero) catch unreachable;
446            computeRR(&m);
447
448            return m;
449        }
450
451        /// Creates a new modulus from a primitive value.
452        /// The modulus must be odd and larger than 2.
453        pub fn fromPrimitive(comptime T: type, x: T) (InvalidModulusError || OverflowError)!Self {
454            comptime assert(@bitSizeOf(T) <= max_bits); // Primitive type is larger than the modulus type.
455            const v = try FeUint.fromPrimitive(T, x);
456            return try Self.fromUint(v);
457        }
458
459        /// Creates a new modulus from a byte string.
460        pub fn fromBytes(bytes: []const u8, comptime endian: Endian) (InvalidModulusError || OverflowError)!Self {
461            const v = try FeUint.fromBytes(bytes, endian);
462            return try Self.fromUint(v);
463        }
464
465        /// Serializes the modulus to a byte string.
466        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
467            return self.v.toBytes(bytes, endian);
468        }
469
470        /// Rejects field elements that are not in the canonical form.
471        pub fn rejectNonCanonical(self: Self, fe: Fe) error{NonCanonical}!void {
472            if (fe.limbs_count() != self.limbs_count() or ct.limbsCmpGeq(fe.v, self.v)) {
473                return error.NonCanonical;
474            }
475        }
476
477        // Makes the number of active limbs in a field element match the one of the modulus.
478        fn shrink(self: Self, fe: *Fe) OverflowError!void {
479            const new_len = self.limbs_count();
480            if (fe.limbs_count() < new_len) return error.Overflow;
481            var acc: Limb = 0;
482            for (fe.v.limbsConst()[new_len..]) |limb| {
483                acc |= limb;
484            }
485            if (acc != 0) return error.Overflow;
486            if (new_len > fe.v.limbs_buffer.len) return error.Overflow;
487            fe.v.limbs_len = new_len;
488        }
489
490        // Computes R^2 for the Montgomery representation.
491        fn computeRR(self: *Self) void {
492            self.rr = self.zero;
493            const n = self.rr.limbs_count();
494            self.rr.v.limbs()[n - 1] = 1;
495            for ((n - 1)..(2 * n)) |_| {
496                self.shiftIn(&self.rr, 0);
497            }
498            self.shrink(&self.rr) catch unreachable;
499        }
500
        /// Computes x << t_bits + y (mod m)
        /// i.e. `x = (x * 2^t_bits + y) mod m`. The low `t_bits` bits of `y` are shifted in one at a
        /// time, most significant first (bit `t_bits - 1` down to bit `0`), so the number of rounds is
        /// fixed; the choice between the reduced and unreduced value is made with constant-time
        /// selects rather than branches.
        /// NOTE(review): presumably requires `x < m` and `x` to have as many active limbs as the
        /// modulus (the inner loop indexes `x_limbs` up to `self.limbs_count()`) — confirm at call sites.
        fn shiftIn(self: Self, x: *Fe, y: Limb) void {
            // `d` receives the candidate "doubled x minus m" computed in each round.
            var d = self.zero;
            const x_limbs = x.v.limbs();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            // Whether the previous round's reduced value (`d`) should be used instead of `x`.
            var need_sub = false;
            var i: usize = t_bits - 1;
            while (true) : (i -= 1) {
                // Bit `i` of `y` enters at the bottom of the least significant limb.
                var carry: u1 = @truncate(math.shr(Limb, y, i));
                var borrow: u1 = 0;
                for (0..self.limbs_count()) |j| {
                    // Constant-time pick of the previous round's result.
                    const l = ct.select(need_sub, d_limbs[j], x_limbs[j]);
                    // x = 2 * l + bit, limb by limb, propagating the carry.
                    var res = (l << 1) + carry;
                    x_limbs[j] = @as(TLimb, @truncate(res));
                    carry = @truncate(res >> t_bits);

                    // d = x - m, limb by limb, propagating the borrow.
                    res = x_limbs[j] -% m_limbs[j] -% borrow;
                    d_limbs[j] = @as(TLimb, @truncate(res));

                    borrow = @truncate(res >> t_bits);
                }
                // Keep `d` when the doubling carried out of the top limb (the true value exceeds the
                // capacity, so x - m is the right result) or when it did not and x >= m (no borrow):
                // in both cases `carry == borrow`.
                need_sub = ct.eql(carry, borrow);
                if (i == 0) break;
            }
            // Apply the last round's decision.
            x.v.cmov(need_sub, d.v);
        }
529
530        /// Adds two field elements (mod m).
531        pub fn add(self: Self, x: Fe, y: Fe) Fe {
532            var out = x;
533            const overflow = out.v.addWithOverflow(y.v);
534            const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
535            const need_sub = ct.eql(overflow, underflow);
536            _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
537            return out;
538        }
539
540        /// Subtracts two field elements (mod m).
541        pub fn sub(self: Self, x: Fe, y: Fe) Fe {
542            var out = x;
543            const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
544            _ = out.v.conditionalAddWithOverflow(underflow, self.v);
545            return out;
546        }
547
548        /// Converts a field element to the Montgomery form.
549        pub fn toMontgomery(self: Self, x: *Fe) RepresentationError!void {
550            if (x.montgomery) {
551                return error.UnexpectedRepresentation;
552            }
553            self.shrink(x) catch unreachable;
554            x.* = self.montgomeryMul(x.*, self.rr);
555            x.montgomery = true;
556        }
557
558        /// Takes a field element out of the Montgomery form.
559        pub fn fromMontgomery(self: Self, x: *Fe) RepresentationError!void {
560            if (!x.montgomery) {
561                return error.UnexpectedRepresentation;
562            }
563            self.shrink(x) catch unreachable;
564            x.* = self.montgomeryMul(x.*, self.one());
565            x.montgomery = false;
566        }
567
568        /// Reduces an arbitrary `Uint`, converting it to a field element.
569        pub fn reduce(self: Self, x: anytype) Fe {
570            var out = self.zero;
571            var i = x.limbs_len - 1;
572            if (self.limbs_count() >= 2) {
573                const start = @min(i, self.limbs_count() - 2);
574                var j = start;
575                while (true) : (j -= 1) {
576                    out.v.limbs()[j] = x.limbsConst()[i];
577                    i -= 1;
578                    if (j == 0) break;
579                }
580            }
581            while (true) : (i -= 1) {
582                self.shiftIn(&out, x.limbsConst()[i]);
583                if (i == 0) break;
584            }
585            return out;
586        }
587
        // Core loop of Montgomery multiplication, processing one limb of `x` at a time (least
        // significant first). For each limb `a_i`, it adds `a_i * y` and `f * m` to `d`, where `f`
        // is derived from `m0inv` so that the low `t_bits` bits of the sum vanish, and shifts the
        // sum down by one limb (the value computed for position `j` is stored at `j - 1`).
        // `d` is expected to be zero on entry; the callers visible here (`montgomeryMul`,
        // `montgomerySq`) pass a copy of `self.zero`.
        // Returns the top bit that does not fit in `d`'s active limbs; callers combine it with a
        // comparison against the modulus to decide on a final subtraction.
        // NOTE(review): `ct.mulWide` is assumed to return the full double-width product as `hi`/`lo`.
        fn montgomeryLoop(self: Self, d: *Fe, x: Fe, y: Fe) u1 {
            // All operands must have as many active limbs as the modulus.
            assert(d.limbs_count() == x.limbs_count());
            assert(d.limbs_count() == y.limbs_count());
            assert(d.limbs_count() == self.limbs_count());

            const a_limbs = x.v.limbsConst();
            const b_limbs = y.v.limbsConst();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            // Extra top bit carried from one outer iteration to the next.
            var overflow: u1 = 0;
            for (0..self.limbs_count()) |i| {
                var carry: Limb = 0;

                // Lowest position: z = d_0 + a_i * b_0, tracked as the (z_hi, z_lo) pair.
                var wide = ct.mulWide(a_limbs[i], b_limbs[0]);
                var z_lo = @addWithOverflow(d_limbs[0], wide.lo);
                // f = z * m0inv (mod 2^t_bits); with m0inv = -m^-1 (mod 2^t_bits) as set up in
                // fromUint, adding f * m_0 clears the low t_bits bits of z.
                const f = @as(TLimb, @truncate(z_lo[0] *% self.m0inv));
                var z_hi = wide.hi +% z_lo[1];
                wide = ct.mulWide(f, m_limbs[0]);
                z_lo = @addWithOverflow(z_lo[0], wide.lo);
                z_hi +%= z_lo[1];
                z_hi +%= wide.hi;
                // carry = (z_hi:z_lo) >> t_bits; the shift of z_hi by 1 relies on carry_bits == 1.
                carry = (z_hi << 1) | (z_lo[0] >> t_bits);

                for (1..self.limbs_count()) |j| {
                    // z = d_j + a_i * b_j + f * m_j + carry
                    wide = ct.mulWide(a_limbs[i], b_limbs[j]);
                    z_lo = @addWithOverflow(d_limbs[j], wide.lo);
                    z_hi = wide.hi +% z_lo[1];
                    wide = ct.mulWide(f, m_limbs[j]);
                    z_lo = @addWithOverflow(z_lo[0], wide.lo);
                    z_hi +%= z_lo[1];
                    z_hi +%= wide.hi;
                    z_lo = @addWithOverflow(z_lo[0], carry);
                    z_hi +%= z_lo[1];
                    // Always true here (`j` starts at 1): the result is shifted down by one limb.
                    if (j > 0) {
                        d_limbs[j - 1] = @as(TLimb, @truncate(z_lo[0]));
                    }
                    carry = (z_hi << 1) | (z_lo[0] >> t_bits);
                }
                // The top limb receives the final carry plus the previous overflow bit; whatever
                // exceeds t_bits becomes the overflow bit for the next iteration.
                const z = overflow + carry;
                d_limbs[self.limbs_count() - 1] = @as(TLimb, @truncate(z));
                overflow = @as(u1, @truncate(z >> t_bits));
            }
            return overflow;
        }
633
634        // Montgomery multiplication.
635        fn montgomeryMul(self: Self, x: Fe, y: Fe) Fe {
636            var d = self.zero;
637            assert(x.limbs_count() == self.limbs_count());
638            assert(y.limbs_count() == self.limbs_count());
639            const overflow = self.montgomeryLoop(&d, x, y);
640            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
641            const need_sub = ct.eql(overflow, underflow);
642            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
643            d.montgomery = x.montgomery == y.montgomery;
644            return d;
645        }
646
647        // Montgomery squaring.
648        fn montgomerySq(self: Self, x: Fe) Fe {
649            var d = self.zero;
650            assert(x.limbs_count() == self.limbs_count());
651            const overflow = self.montgomeryLoop(&d, x, x);
652            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
653            const need_sub = ct.eql(overflow, underflow);
654            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
655            d.montgomery = true;
656            return d;
657        }
658
659        // Returns x^e (mod m), with the exponent provided as a byte string.
660        // `public` must be set to `false` if the exponent it secret.
661        fn powWithEncodedExponentInternal(self: Self, x: Fe, e: []const u8, endian: Endian, comptime public: bool) NullExponentError!Fe {
662            var acc: u8 = 0;
663            for (e) |b| acc |= b;
664            if (acc == 0) return error.NullExponent;
665
666            var out = self.one();
667            self.toMontgomery(&out) catch unreachable;
668
669            if (public and e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111)) {
670                // Do not use a precomputation table for short, public exponents
671                var x_m = x;
672                if (x.montgomery == false) {
673                    self.toMontgomery(&x_m) catch unreachable;
674                }
675                var s = switch (endian) {
676                    .big => 0,
677                    .little => e.len - 1,
678                };
679                while (true) {
680                    const b = e[s];
681                    var j: u3 = 7;
682                    while (true) : (j -= 1) {
683                        out = self.montgomerySq(out);
684                        const k: u1 = @truncate(b >> j);
685                        if (k != 0) {
686                            const t = self.montgomeryMul(out, x_m);
687                            @memcpy(out.v.limbs(), t.v.limbsConst());
688                        }
689                        if (j == 0) break;
690                    }
691                    switch (endian) {
692                        .big => {
693                            s += 1;
694                            if (s == e.len) break;
695                        },
696                        .little => {
697                            if (s == 0) break;
698                            s -= 1;
699                        },
700                    }
701                }
702            } else {
703                // Use a precomputation table for large exponents
704                var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
705                if (x.montgomery == false) {
706                    self.toMontgomery(&pc[0]) catch unreachable;
707                }
708                for (1..pc.len) |i| {
709                    pc[i] = self.montgomeryMul(pc[i - 1], pc[0]);
710                }
711                var t0 = self.zero;
712                var s = switch (endian) {
713                    .big => 0,
714                    .little => e.len - 1,
715                };
716                while (true) {
717                    const b = e[s];
718                    for ([_]u3{ 4, 0 }) |j| {
719                        for (0..4) |_| {
720                            out = self.montgomerySq(out);
721                        }
722                        const k = (b >> j) & 0b1111;
723                        if (public or std.options.side_channels_mitigations == .none) {
724                            if (k == 0) continue;
725                            t0 = pc[k - 1];
726                        } else {
727                            for (pc, 0..) |t, i| {
728                                t0.v.cmov(ct.eql(k, @as(u8, @truncate(i + 1))), t.v);
729                            }
730                        }
731                        const t1 = self.montgomeryMul(out, t0);
732                        if (public) {
733                            @memcpy(out.v.limbs(), t1.v.limbsConst());
734                        } else {
735                            out.v.cmov(!ct.eql(k, 0), t1.v);
736                        }
737                    }
738                    switch (endian) {
739                        .big => {
740                            s += 1;
741                            if (s == e.len) break;
742                        },
743                        .little => {
744                            if (s == 0) break;
745                            s -= 1;
746                        },
747                    }
748                }
749            }
750            self.fromMontgomery(&out) catch unreachable;
751            return out;
752        }
753
754        /// Multiplies two field elements.
755        pub fn mul(self: Self, x: Fe, y: Fe) Fe {
756            if (x.montgomery != y.montgomery) {
757                return self.montgomeryMul(x, y);
758            }
759            var a_ = x;
760            if (x.montgomery == false) {
761                self.toMontgomery(&a_) catch unreachable;
762            } else {
763                self.fromMontgomery(&a_) catch unreachable;
764            }
765            return self.montgomeryMul(a_, y);
766        }
767
768        /// Squares a field element.
769        pub fn sq(self: Self, x: Fe) Fe {
770            var out = x;
771            if (x.montgomery == true) {
772                self.fromMontgomery(&out) catch unreachable;
773            }
774            out = self.montgomerySq(out);
775            out.montgomery = false;
776            self.toMontgomery(&out) catch unreachable;
777            return out;
778        }
779
780        /// Returns x^e (mod m) in constant time.
781        pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
782            var buf: [Fe.encoded_bytes]u8 = undefined;
783            e.toBytes(&buf, native_endian) catch unreachable;
784            return self.powWithEncodedExponent(x, &buf, native_endian);
785        }
786
787        /// Returns x^e (mod m), assuming that the exponent is public.
788        /// The function remains constant time with respect to `x`.
789        pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
790            var e_normalized = Fe{ .v = e.v.normalize() };
791            var buf_: [Fe.encoded_bytes]u8 = undefined;
792            var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_len * t_bits, 8) catch unreachable];
793            e_normalized.toBytes(buf, .little) catch unreachable;
794            const leading = @clz(e_normalized.v.limbsConst()[e_normalized.v.limbs_len - carry_bits]);
795            buf = buf[0 .. buf.len - leading / 8];
796            return self.powWithEncodedPublicExponent(x, buf, .little);
797        }
798
799        /// Returns x^e (mod m), with the exponent provided as a byte string.
800        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
801        /// doesn't have to be created if a serialized representation is already available.
802        ///
803        /// If the exponent is public, `powWithEncodedPublicExponent()` can be used instead for a slight speedup.
804        pub fn powWithEncodedExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
805            return self.powWithEncodedExponentInternal(x, e, endian, false);
806        }
807
808        /// Returns x^e (mod m), the exponent being public and provided as a byte string.
809        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
810        /// doesn't have to be created if a serialized representation is already available.
811        ///
812        /// If the exponent is secret, `powWithEncodedExponent` must be used instead.
813        pub fn powWithEncodedPublicExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
814            return self.powWithEncodedExponentInternal(x, e, endian, true);
815        }
816    };
817}
818
// Primitive set used by the arithmetic above. The hardened `ct_protected` variants are used
// unless side-channel mitigations were explicitly disabled through `std.options`.
const ct = switch (std.options.side_channels_mitigations) {
    .none => ct_unprotected,
    else => ct_protected,
};
820
// Side-channel-hardened primitives. `ct` selects them unless
// `std.options.side_channels_mitigations` is `.none`.
// These helpers intentionally use masks and borrow bits instead of data-dependent branches;
// keep them that way when editing.
const ct_protected = struct {
    // Returns x if on is true, otherwise y.
    // `0 -% @intFromBool(on)` is all ones when `on` is true and zero otherwise; the XOR/AND
    // blend then picks x or y without branching.
    fn select(on: bool, x: Limb, y: Limb) Limb {
        const mask = @as(Limb, 0) -% @intFromBool(on);
        return y ^ (mask & (y ^ x));
    }

    // Compares two values in constant time.
    // c1 and c2 are the borrow bits of x - y and y - x. Both are zero exactly when x == y,
    // so `1 - (c1 | c2)` is 1 for equal inputs and 0 otherwise.
    fn eql(x: anytype, y: @TypeOf(x)) bool {
        const c1 = @subWithOverflow(x, y)[1];
        const c2 = @subWithOverflow(y, x)[1];
        return @as(bool, @bitCast(1 - (c1 | c2)));
    }

    // Compares two big integers in constant time, returning true if x < y.
    // Propagates the borrow of x - y across all limbs, starting at index 0 (the least
    // significant limb). The final borrow is set exactly when x < y. Every limb pair is
    // visited regardless of the values, and `for` over two slices requires equal lengths.
    // The borrow is read from bit `t_bits` of the wrapped difference. This relies on limbs
    // using only their low `t_bits` bits (see `carry_bits`).
    fn limbsCmpLt(x: anytype, y: @TypeOf(x)) bool {
        var c: u1 = 0;
        for (x.limbsConst(), y.limbsConst()) |x_limb, y_limb| {
            c = @truncate((x_limb -% y_limb -% c) >> t_bits);
        }
        return c != 0;
    }

    // Compares two big integers in constant time, returning true if x >= y.
    fn limbsCmpGeq(x: anytype, y: @TypeOf(x)) bool {
        return !limbsCmpLt(x, y);
    }

    // Multiplies two limbs and returns the result as a wide limb.
    // Schoolbook multiplication on half-width digits: x = x1*2^h + x0 and y = y1*2^h + y0,
    // where h is half the limb width. Each half-by-half product fits in one Limb. The high
    // word is assembled from the partial products plus the carries w2 and (w1 >> h). The low
    // word is simply the wrapping product x *% y.
    // NOTE(review): presumably this avoids a double-width multiply that may be variable-time
    // on some targets — confirm.
    fn mulWide(x: Limb, y: Limb) WideLimb {
        const half_bits = @typeInfo(Limb).int.bits / 2;
        const Half = meta.Int(.unsigned, half_bits);
        const x0 = @as(Half, @truncate(x));
        const x1 = @as(Half, @truncate(x >> half_bits));
        const y0 = @as(Half, @truncate(y));
        const y1 = @as(Half, @truncate(y >> half_bits));
        const w0 = math.mulWide(Half, x0, y0);
        const t = math.mulWide(Half, x1, y0) + (w0 >> half_bits);
        var w1: Limb = @as(Half, @truncate(t));
        const w2 = @as(Half, @truncate(t >> half_bits));
        w1 += math.mulWide(Half, x0, y1);
        const hi = math.mulWide(Half, x1, y1) + w2 + (w1 >> half_bits);
        const lo = x *% y;
        return .{ .hi = hi, .lo = lo };
    }
};
867
// Straightforward implementations of the `ct` primitives.
// These use ordinary branches and early exits, so they are NOT constant time.
// `ct` selects them only when `std.options.side_channels_mitigations` is `.none`.
const ct_unprotected = struct {
    // Picks `x` when `on` is set, `y` otherwise.
    fn select(on: bool, x: Limb, y: Limb) Limb {
        if (on) return x;
        return y;
    }

    // Equality using the native `==` operator.
    fn eql(x: anytype, y: @TypeOf(x)) bool {
        const same = x == y;
        return same;
    }

    // Returns true if the big integer `x` is strictly less than `y`.
    // Scans from the most significant limb down; the first differing limb decides.
    fn limbsCmpLt(x: anytype, y: @TypeOf(x)) bool {
        const a = x.limbsConst();
        const b = y.limbsConst();
        assert(a.len == b.len);

        var idx: usize = a.len;
        while (idx > 0) : (idx -= 1) {
            const a_limb = a[idx - 1];
            const b_limb = b[idx - 1];
            if (a_limb != b_limb) return a_limb < b_limb;
        }
        return false;
    }

    // Returns true if the big integer `x` is greater than or equal to `y`.
    fn limbsCmpGeq(x: anytype, y: @TypeOf(x)) bool {
        return !limbsCmpLt(x, y);
    }

    // Full double-width product of two limbs, split into high and low words.
    fn mulWide(x: Limb, y: Limb) WideLimb {
        const product = math.mulWide(Limb, x, y);
        const hi: Limb = @truncate(product >> @typeInfo(Limb).int.bits);
        const lo: Limb = @truncate(product);
        return .{ .hi = hi, .lo = lo };
    }
};
909
test "finite field arithmetic" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    const M = Modulus(256);
    const m = try M.fromPrimitive(u256, 3429938563481314093726330772853735541133072814650493833233);
    var x = try M.Fe.fromPrimitive(u256, m, 80169837251094269539116136208111827396136208141182357733);
    var y = try M.Fe.fromPrimitive(u256, m, 24620149608466364616251608466389896540098571);

    // Round-trip through a primitive integer, and reject a destination type that is too narrow.
    const x_ = try x.toPrimitive(u256);
    try testing.expect((try M.Fe.fromPrimitive(@TypeOf(x_), m, x_)).eql(x));
    try testing.expectError(error.Overflow, x.toPrimitive(u50));

    // All `expectEqual` calls below pass the expected value first, so failure messages
    // report expected and actual values the right way round.
    const bits = m.bits();
    try testing.expectEqual(192, bits);

    // Multiplication gives the same value whether or not `x` is in Montgomery form.
    var x_y = m.mul(x, y);
    try testing.expectEqual(@as(u256, 1666576607955767413750776202132407807424848069716933450241), try x_y.toPrimitive(u256));

    try m.toMontgomery(&x);
    x_y = m.mul(x, y);
    try testing.expectEqual(@as(u256, 1666576607955767413750776202132407807424848069716933450241), try x_y.toPrimitive(u256));
    try m.fromMontgomery(&x);

    x = m.add(x, y);
    try testing.expectEqual(@as(u256, 80169837251118889688724602572728079004602598037722456304), try x.toPrimitive(u256));
    x = m.sub(x, y);
    try testing.expectEqual(@as(u256, 80169837251094269539116136208111827396136208141182357733), try x.toPrimitive(u256));

    // Reduction of an integer wider than the modulus.
    const big = try Uint(512).fromPrimitive(u495, 77285373554113307281465049383342993856348131409372633077285373554113307281465049383323332333429938563481314093726330772853735541133072814650493833233);
    const reduced = m.reduce(big);
    try testing.expectEqual(@as(u495, 858047099884257670294681641776170038885500210968322054970), try reduced.toPrimitive(u495));

    // Public exponentiation, in both representations, and rejection of a zero exponent.
    const x_pow_y = try m.powPublic(x, y);
    try testing.expectEqual(@as(u256, 1631933139300737762906024873185789093007782131928298618473), try x_pow_y.toPrimitive(u256));
    try m.toMontgomery(&x);
    const x_pow_y2 = try m.powPublic(x, y);
    try m.fromMontgomery(&x);
    try testing.expect(x_pow_y2.eql(x_pow_y));
    try testing.expectError(error.NullExponent, m.powPublic(x, m.zero));

    try testing.expect(!x.isZero());
    try testing.expect(!y.isZero());
    try testing.expect(m.v.isOdd());

    // Squaring agrees with self-multiplication in both representations.
    const x_sq = m.sq(x);
    const x_sq2 = m.mul(x, x);
    try testing.expect(x_sq.eql(x_sq2));
    try m.toMontgomery(&x);
    const x_sq3 = m.sq(x);
    const x_sq4 = m.mul(x, x);
    try testing.expect(x_sq.eql(x_sq3));
    try testing.expect(x_sq3.eql(x_sq4));
    try m.fromMontgomery(&x);
}
964
// Shared checks for a `ct` implementation (`ct_protected` or `ct_unprotected`).
fn testCt(ct_: anytype) !void {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    const l0: Limb = 0;
    const l1: Limb = 1;
    try testing.expectEqual(l1, ct_.select(true, l1, l0));
    try testing.expectEqual(l0, ct_.select(false, l1, l0));
    try testing.expectEqual(false, ct_.eql(l1, l0));
    try testing.expectEqual(true, ct_.eql(l1, l1));

    const M = Modulus(256);
    const m = try M.fromPrimitive(u256, 3429938563481314093726330772853735541133072814650493833233);
    const x = try M.Fe.fromPrimitive(u256, m, 80169837251094269539116136208111827396136208141182357733);
    const y = try M.Fe.fromPrimitive(u256, m, 24620149608466364616251608466389896540098571);
    try testing.expectEqual(false, ct_.limbsCmpLt(x.v, y.v));
    try testing.expectEqual(true, ct_.limbsCmpGeq(x.v, y.v));
    // The converse ordering (y < x) and equal operands.
    try testing.expectEqual(true, ct_.limbsCmpLt(y.v, x.v));
    try testing.expectEqual(false, ct_.limbsCmpGeq(y.v, x.v));
    try testing.expectEqual(false, ct_.limbsCmpLt(x.v, x.v));
    try testing.expectEqual(true, ct_.limbsCmpGeq(x.v, x.v));

    try testing.expectEqual(WideLimb{ .hi = 0, .lo = 0x88 }, ct_.mulWide(1 << 3, (1 << 4) + 1));
    // Products that carry into the high word:
    // (2^n - 1)^2 = (2^n - 2) * 2^n + 1, and 2^(n-1) * 2 = 2^n.
    const limb_max = math.maxInt(Limb);
    try testing.expectEqual(WideLimb{ .hi = limb_max - 1, .lo = 1 }, ct_.mulWide(limb_max, limb_max));
    try testing.expectEqual(WideLimb{ .hi = 1, .lo = 0 }, ct_.mulWide(1 << (@bitSizeOf(Limb) - 1), 2));
}
984
test ct {
    // Exercise both implementations, regardless of which one `ct` currently selects.
    inline for (.{ ct_protected, ct_unprotected }) |impl| {
        try testCt(impl);
    }
}