master
1//! Allocation-free, (best-effort) constant-time, finite field arithmetic for large integers.
2//!
3//! Unlike `std.math.big`, these integers have a fixed maximum length and are only designed to be used for modular arithmetic.
4//! Arithmetic operations are meant to run in constant-time for a given modulus, making them suitable for cryptography.
5//!
//! Parts of this code were ported from the BSD-licensed crypto/internal/bigmod/nat.go file in the Go language, itself inspired by BearSSL.
7
8const std = @import("std");
9const builtin = @import("builtin");
10const crypto = std.crypto;
11const math = std.math;
12const mem = std.mem;
13const meta = std.meta;
14const testing = std.testing;
15const assert = std.debug.assert;
16const Endian = std.builtin.Endian;
17
// A Limb is a single digit in a big integer.
const Limb = usize;

// The number of reserved bits in a Limb (kept clear so additions can carry into it).
const carry_bits = 1;

// The number of active (value-carrying) bits in a Limb.
const t_bits: usize = @bitSizeOf(Limb) - carry_bits;

// A TLimb is a Limb that is truncated to t_bits.
const TLimb = meta.Int(.unsigned, t_bits);

// Byte order of the compilation target.
const native_endian = builtin.target.cpu.arch.endian();

// A WideLimb is the double-width product of two limbs, split into its high and low halves.
const WideLimb = struct {
    hi: Limb,
    lo: Limb,
};

/// Value is too large for the destination.
pub const OverflowError = error{Overflow};

/// Invalid modulus: the modulus must be odd (`EvenModulus`) and larger than 2 (`ModulusTooSmall`).
pub const InvalidModulusError = error{ EvenModulus, ModulusTooSmall };

/// Exponentiation with a zero ("null") exponent.
/// Exponentiation in cryptographic protocols is almost always a sign of a bug which can lead to trivial attacks.
/// Therefore, this module returns an error when a null exponent is encountered, encouraging applications to handle this case explicitly.
pub const NullExponentError = error{NullExponent};

/// Invalid field element for the given modulus (not strictly below the modulus).
pub const FieldElementError = error{NonCanonical};

/// Invalid representation (Montgomery vs non-Montgomery domain).
pub const RepresentationError = error{UnexpectedRepresentation};

/// The set of all possible errors `std.crypto.ff` functions can return.
pub const Error = OverflowError || InvalidModulusError || NullExponentError || FieldElementError || RepresentationError;
57
/// An unsigned big integer with a fixed maximum size (`max_bits`), suitable for cryptographic operations.
/// Unless side-channels mitigations are explicitly disabled, operations are designed to be constant-time.
pub fn Uint(comptime max_bits: comptime_int) type {
    comptime assert(@bitSizeOf(Limb) % 8 == 0); // Limb size must be a multiple of 8

    return struct {
        const Self = @This();
        const max_limbs_count = math.divCeil(usize, max_bits, t_bits) catch unreachable;

        limbs_buffer: [max_limbs_count]Limb,
        /// The number of active limbs.
        limbs_len: usize,

        /// Number of bytes required to serialize an integer.
        pub const encoded_bytes = math.divCeil(usize, max_bits, 8) catch unreachable;

        /// Constant slice of active limbs.
        fn limbsConst(self: *const Self) []const Limb {
            return self.limbs_buffer[0..self.limbs_len];
        }

        /// Mutable slice of active limbs.
        fn limbs(self: *Self) []Limb {
            return self.limbs_buffer[0..self.limbs_len];
        }

        // Returns a copy whose active limbs exclude the zero limbs at the top.
        // At least one limb always stays active.
        fn normalize(self: Self) Self {
            var res = self;
            if (self.limbs_len < 2) {
                return res;
            }
            var i = self.limbs_len - 1;
            while (i > 0 and res.limbsConst()[i] == 0) : (i -= 1) {}
            res.limbs_len = i + 1;
            assert(res.limbs_len <= res.limbs_buffer.len);
            return res;
        }

        // Returns `true` if every active limb at index `from` or above is zero.
        // All of those limbs are always read (OR-accumulated).
        fn limbsZeroFrom(self: Self, from: usize) bool {
            var acc: Limb = 0;
            for (self.limbsConst()[from..]) |limb| {
                acc |= limb;
            }
            return acc == 0;
        }

        /// The zero integer.
        pub const zero: Self = .{
            .limbs_buffer = [1]Limb{0} ** max_limbs_count,
            .limbs_len = max_limbs_count,
        };

        /// Creates a new big integer from a primitive type.
        /// Returns `error.Overflow` if the value needs more than `max_limbs_count` limbs.
        /// This function may not run in constant time.
        pub fn fromPrimitive(comptime T: type, init_value: T) OverflowError!Self {
            var x = init_value;
            var out: Self = .{
                .limbs_buffer = undefined,
                .limbs_len = max_limbs_count,
            };
            for (&out.limbs_buffer) |*limb| {
                limb.* = if (@bitSizeOf(T) > t_bits) @as(TLimb, @truncate(x)) else x;
                x = math.shr(T, x, t_bits);
            }
            if (x != 0) {
                return error.Overflow;
            }
            return out;
        }

        /// Converts a big integer to a primitive type.
        /// Returns `error.Overflow` if the value does not fit in `T`.
        /// This function may not run in constant time.
        pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
            var x: T = 0;
            var i = self.limbs_len - 1;
            while (true) : (i -= 1) {
                // Shifting the partial result left by `t_bits` must not drop any set bit.
                // If `T` is not wider than a limb, any non-zero partial result is shifted out.
                if (x != 0) {
                    if (comptime @bitSizeOf(T) <= t_bits) {
                        return error.Overflow;
                    } else if (math.shr(T, x, @bitSizeOf(T) - t_bits) != 0) {
                        return error.Overflow;
                    }
                }
                x = math.shl(T, x, t_bits);
                const v = math.cast(T, self.limbsConst()[i]) orelse return error.Overflow;
                x |= v;
                if (i == 0) break;
            }
            return x;
        }

        /// Encodes a big integer into a byte array, zero-padding unused high-order bytes.
        /// Returns `error.Overflow` only if the value does not fit in `bytes`.
        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
            if (bytes.len == 0) {
                if (self.isZero()) return;
                return error.Overflow;
            }
            @memset(bytes, 0);
            var shift: usize = 0;
            var out_i: usize = switch (endian) {
                .big => bytes.len - 1,
                .little => 0,
            };
            for (0..self.limbs_len) |i| {
                var remaining_bits = t_bits;
                var limb = self.limbsConst()[i];
                while (remaining_bits >= 8) {
                    bytes[out_i] |= math.shl(u8, @as(u8, @truncate(limb)), shift);
                    const consumed = 8 - shift;
                    limb >>= @as(u4, @truncate(consumed));
                    remaining_bits -= consumed;
                    shift = 0;
                    const buffer_full = switch (endian) {
                        .big => out_i == 0,
                        .little => out_i == bytes.len - 1,
                    };
                    if (buffer_full) {
                        // The value fits only if nothing is left to write: neither the
                        // unwritten bits of this limb nor any higher limb.
                        if (limb != 0 or !self.limbsZeroFrom(i + 1)) {
                            return error.Overflow;
                        }
                        return;
                    }
                    switch (endian) {
                        .big => out_i -= 1,
                        .little => out_i += 1,
                    }
                }
                bytes[out_i] |= @as(u8, @truncate(limb));
                shift = remaining_bits;
            }
        }

        /// Creates a new big integer from a byte array.
        /// Returns `error.Overflow` if the input has more significant bits than the integer can hold.
        pub fn fromBytes(bytes: []const u8, comptime endian: Endian) OverflowError!Self {
            if (bytes.len == 0) return Self.zero;
            var shift: usize = 0;
            var out = Self.zero;
            var out_i: usize = 0;
            var i: usize = switch (endian) {
                .big => bytes.len - 1,
                .little => 0,
            };
            while (true) {
                const bi = bytes[i];
                // Bytes are consumed least significant first, so the last one depends on the order.
                const is_last_byte = switch (endian) {
                    .big => i == 0,
                    .little => i == bytes.len - 1,
                };
                out.limbs()[out_i] |= math.shl(Limb, bi, shift);
                shift += 8;
                if (shift >= t_bits) {
                    shift -= t_bits;
                    out.limbs()[out_i] = @as(TLimb, @truncate(out.limbs()[out_i]));
                    const overflow = math.shr(Limb, bi, 8 - shift);
                    out_i += 1;
                    if (out_i >= out.limbs_len) {
                        if (overflow != 0 or !is_last_byte) {
                            return error.Overflow;
                        }
                        break;
                    }
                    out.limbs()[out_i] = overflow;
                }
                if (is_last_byte) break;
                switch (endian) {
                    .big => i -= 1,
                    .little => i += 1,
                }
            }
            return out;
        }

        /// Returns `true` if both integers are equal.
        /// Compares the whole limb buffers (inactive limbs included) with `crypto.timing_safe.eql`.
        pub fn eql(x: Self, y: Self) bool {
            return crypto.timing_safe.eql([max_limbs_count]Limb, x.limbs_buffer, y.limbs_buffer);
        }

        /// Compares two integers using `crypto.timing_safe.compare` over the active limbs.
        pub fn compare(x: Self, y: Self) math.Order {
            return crypto.timing_safe.compare(
                Limb,
                x.limbsConst(),
                y.limbsConst(),
                .little,
            );
        }

        /// Returns `true` if the integer is zero.
        pub fn isZero(x: Self) bool {
            var t: Limb = 0;
            for (x.limbsConst()) |elem| {
                t |= elem;
            }
            return ct.eql(t, 0);
        }

        /// Returns `true` if the integer is odd.
        pub fn isOdd(x: Self) bool {
            return @as(u1, @truncate(x.limbsConst()[0])) != 0;
        }

        /// Adds `y` to `x`, and returns `1` if the operation overflowed, `0` otherwise.
        pub fn addWithOverflow(x: *Self, y: Self) u1 {
            return x.conditionalAddWithOverflow(true, y);
        }

        /// Subtracts `y` from `x`, and returns `1` if the operation underflowed, `0` otherwise.
        pub fn subWithOverflow(x: *Self, y: Self) u1 {
            return x.conditionalSubWithOverflow(true, y);
        }

        // Replaces the limbs of `x` with the limbs of `y` if `on` is `true`.
        fn cmov(x: *Self, on: bool, y: Self) void {
            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
                x_limb.* = ct.select(on, y_limb, x_limb.*);
            }
        }

        // Adds `y` to `x` if `on` is `true`. Returns the final carry (computed
        // unconditionally, so it is reported even when `on` is `false`).
        fn conditionalAddWithOverflow(x: *Self, on: bool, y: Self) u1 {
            var carry: u1 = 0;
            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
                const res = x_limb.* + y_limb + carry;
                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
                carry = @truncate(res >> t_bits);
            }
            return carry;
        }

        // Subtracts `y` from `x` if `on` is `true`. Returns the final borrow
        // (computed unconditionally, so it is reported even when `on` is `false`).
        fn conditionalSubWithOverflow(x: *Self, on: bool, y: Self) u1 {
            var borrow: u1 = 0;
            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
                const res = x_limb.* -% y_limb -% borrow;
                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
                borrow = @truncate(res >> t_bits);
            }
            return borrow;
        }
    };
}

test "Uint: conversions only overflow when the value does not fit" {
    const U = Uint(256);

    // 2^t_bits + 5 needs more than 8 bits, even though every limb fits in a u8.
    var u = U.zero;
    u.limbs_buffer[0] = 5;
    u.limbs_buffer[1] = 1;
    try testing.expectError(error.Overflow, u.toPrimitive(u8));

    // Values that fit must serialize into a short buffer, whatever the number of limbs.
    var buf: [1]u8 = undefined;
    try U.zero.toBytes(&buf, .little);
    try testing.expectEqual(@as(u8, 0), buf[0]);
    const one = try U.fromPrimitive(u8, 1);
    try one.toBytes(&buf, .big);
    try testing.expectEqual(@as(u8, 1), buf[0]);
    const wide = try U.fromPrimitive(u16, 0x1234);
    try testing.expectError(error.Overflow, wide.toBytes(&buf, .little));

    // An encoding that exactly fills every limb is accepted in both byte orders.
    const full_len = comptime math.divCeil(usize, U.max_limbs_count * t_bits, 8) catch unreachable;
    var le = [_]u8{0} ** full_len;
    le[0] = 1;
    const from_le = try U.fromBytes(&le, .little);
    try testing.expectEqual(@as(u8, 1), try from_le.toPrimitive(u8));
    const all_ones = [_]u8{0xff} ** full_len;
    try testing.expectError(error.Overflow, U.fromBytes(&all_ones, .big));
}
297
/// A field element: a `Uint` value plus a flag that records whether the value
/// is stored in Montgomery form.
fn Fe_(comptime bits: comptime_int) type {
    return struct {
        const Self = @This();

        const FeUint = Uint(bits);

        /// The element value as a `Uint`.
        v: FeUint,

        /// `true` if the element is in Montgomery form.
        montgomery: bool = false,

        /// The maximum number of bytes required to encode a field element.
        pub const encoded_bytes = FeUint.encoded_bytes;

        // Number of limbs currently active in the underlying value.
        fn limbs_count(self: Self) usize {
            const value = self.v;
            return value.limbs_len;
        }

        // Shared tail of the constructors. Resizes `fe` to the modulus' limb count,
        // then rejects values that are not strictly below the modulus.
        fn canonicalize(m: Modulus(bits), fe: *Self) (OverflowError || FieldElementError)!void {
            try m.shrink(fe);
            try m.rejectNonCanonical(fe.*);
        }

        /// Creates a field element from a primitive.
        /// This function may not run in constant time.
        pub fn fromPrimitive(comptime T: type, m: Modulus(bits), x: T) (OverflowError || FieldElementError)!Self {
            comptime assert(@bitSizeOf(T) <= bits); // Primitive type is larger than the modulus type.
            var fe: Self = .{ .v = try FeUint.fromPrimitive(T, x) };
            try canonicalize(m, &fe);
            return fe;
        }

        /// Converts the field element to a primitive.
        /// The stored value is converted as-is, whatever its representation.
        /// This function may not run in constant time.
        pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
            return FeUint.toPrimitive(self.v, T);
        }

        /// Creates a field element from a byte string.
        pub fn fromBytes(m: Modulus(bits), bytes: []const u8, comptime endian: Endian) (OverflowError || FieldElementError)!Self {
            var fe: Self = .{ .v = try FeUint.fromBytes(bytes, endian) };
            try canonicalize(m, &fe);
            return fe;
        }

        /// Converts the field element to a byte string.
        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
            return FeUint.toBytes(self.v, bytes, endian);
        }

        /// Returns `true` if the field elements have the same stored value, in constant time.
        pub fn eql(x: Self, y: Self) bool {
            return FeUint.eql(x.v, y.v);
        }

        /// Compares the stored values of two field elements in constant time.
        pub fn compare(x: Self, y: Self) math.Order {
            return FeUint.compare(x.v, y.v);
        }

        /// Returns `true` if the element is zero.
        pub fn isZero(self: Self) bool {
            return FeUint.isZero(self.v);
        }

        /// Returns `true` if the element is odd.
        pub fn isOdd(self: Self) bool {
            return FeUint.isOdd(self.v);
        }
    };
}
371
/// A modulus, defining a finite field.
/// All operations within the field are performed modulo this modulus, without heap allocations.
/// `max_bits` represents the number of bits in the maximum value the modulus can be set to.
pub fn Modulus(comptime max_bits: comptime_int) type {
    return struct {
        const Self = @This();

        /// A field element, representing a value within the field defined by this modulus.
        pub const Fe = Fe_(max_bits);

        const FeUint = Fe.FeUint;

        /// The neutral element.
        zero: Fe,

        /// The modulus value.
        v: FeUint,

        /// R^2 for the Montgomery representation.
        rr: Fe,
        /// `-m^-1 mod 2^t_bits`, derived from the lowest limb of the modulus.
        m0inv: Limb,
        /// Number of leading zero bits in the top limb of the modulus, not counting the carry bit.
        leading: usize,

        // Number of active limbs in the modulus.
        fn limbs_count(self: Self) usize {
            return self.v.limbs_len;
        }

        /// Actual size of the modulus, in bits.
        pub fn bits(self: Self) usize {
            return self.limbs_count() * t_bits - self.leading;
        }

        /// Returns the element `1`, in normal (non-Montgomery) form.
        pub fn one(self: Self) Fe {
            var fe = self.zero;
            fe.v.limbs()[0] = 1;
            return fe;
        }

        /// Creates a new modulus from a `Uint` value.
        /// The modulus must be odd and larger than 2.
        pub fn fromUint(v_: FeUint) InvalidModulusError!Self {
            if (!v_.isOdd()) return error.EvenModulus;

            const v = v_.normalize();
            const hi = v.limbsConst()[v.limbs_len - 1];
            const lo = v.limbsConst()[0];

            if (v.limbs_len < 2 and lo < 3) {
                return error.ModulusTooSmall;
            }

            const leading = @clz(hi) - carry_bits;

            // Newton iteration for lo^-1 modulo a power of two: every odd number is its own
            // inverse mod 8, and each step doubles the number of correct low bits.
            var y = lo;
            inline for (0..comptime math.log2_int(usize, t_bits)) |_| {
                y = y *% (2 -% lo *% y);
            }
            const m0inv = (@as(Limb, 1) << t_bits) - (@as(TLimb, @truncate(y)));

            const zero = Fe{ .v = FeUint.zero };

            var m = Self{
                .zero = zero,
                .v = v,
                .leading = leading,
                .m0inv = m0inv,
                .rr = undefined, // will be computed right after
            };
            m.shrink(&m.zero) catch unreachable;
            computeRR(&m);

            return m;
        }

        /// Creates a new modulus from a primitive value.
        /// The modulus must be odd and larger than 2.
        pub fn fromPrimitive(comptime T: type, x: T) (InvalidModulusError || OverflowError)!Self {
            comptime assert(@bitSizeOf(T) <= max_bits); // Primitive type is larger than the modulus type.
            const v = try FeUint.fromPrimitive(T, x);
            return try Self.fromUint(v);
        }

        /// Creates a new modulus from a byte string.
        pub fn fromBytes(bytes: []const u8, comptime endian: Endian) (InvalidModulusError || OverflowError)!Self {
            const v = try FeUint.fromBytes(bytes, endian);
            return try Self.fromUint(v);
        }

        /// Serializes the modulus to a byte string.
        pub fn toBytes(self: Self, bytes: []u8, comptime endian: Endian) OverflowError!void {
            return self.v.toBytes(bytes, endian);
        }

        /// Rejects field elements that are not in the canonical form
        /// (a different limb count than the modulus, or a value >= the modulus).
        pub fn rejectNonCanonical(self: Self, fe: Fe) error{NonCanonical}!void {
            if (fe.limbs_count() != self.limbs_count() or ct.limbsCmpGeq(fe.v, self.v)) {
                return error.NonCanonical;
            }
        }

        // Makes the number of active limbs in a field element match the one of the modulus.
        // Fails if the element has fewer limbs, or if any limb that would be dropped is non-zero.
        fn shrink(self: Self, fe: *Fe) OverflowError!void {
            const new_len = self.limbs_count();
            if (fe.limbs_count() < new_len) return error.Overflow;
            var acc: Limb = 0;
            for (fe.v.limbsConst()[new_len..]) |limb| {
                acc |= limb;
            }
            if (acc != 0) return error.Overflow;
            if (new_len > fe.v.limbs_buffer.len) return error.Overflow;
            fe.v.limbs_len = new_len;
        }

        // Computes R^2 for the Montgomery representation. It starts from
        // 2^(t_bits * (n - 1)), which is below the modulus, and shifts in n + 1 zero limbs.
        fn computeRR(self: *Self) void {
            self.rr = self.zero;
            const n = self.rr.limbs_count();
            self.rr.v.limbs()[n - 1] = 1;
            for ((n - 1)..(2 * n)) |_| {
                self.shiftIn(&self.rr, 0);
            }
            self.shrink(&self.rr) catch unreachable;
        }

        /// Computes x << t_bits + y (mod m), one bit at a time.
        fn shiftIn(self: Self, x: *Fe, y: Limb) void {
            var d = self.zero;
            const x_limbs = x.v.limbs();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            var need_sub = false;
            var i: usize = t_bits - 1;
            while (true) : (i -= 1) {
                var carry: u1 = @truncate(math.shr(Limb, y, i));
                var borrow: u1 = 0;
                for (0..self.limbs_count()) |j| {
                    const l = ct.select(need_sub, d_limbs[j], x_limbs[j]);
                    var res = (l << 1) + carry;
                    x_limbs[j] = @as(TLimb, @truncate(res));
                    carry = @truncate(res >> t_bits);

                    res = x_limbs[j] -% m_limbs[j] -% borrow;
                    d_limbs[j] = @as(TLimb, @truncate(res));

                    borrow = @truncate(res >> t_bits);
                }
                need_sub = ct.eql(carry, borrow);
                if (i == 0) break;
            }
            x.v.cmov(need_sub, d.v);
        }

        /// Adds two field elements (mod m).
        /// Both operands are expected to use the same representation; the result keeps `x`'s flag.
        pub fn add(self: Self, x: Fe, y: Fe) Fe {
            var out = x;
            const overflow = out.v.addWithOverflow(y.v);
            const underflow: u1 = @bitCast(ct.limbsCmpLt(out.v, self.v));
            const need_sub = ct.eql(overflow, underflow);
            _ = out.v.conditionalSubWithOverflow(need_sub, self.v);
            return out;
        }

        /// Subtracts two field elements (mod m).
        /// Both operands are expected to use the same representation; the result keeps `x`'s flag.
        pub fn sub(self: Self, x: Fe, y: Fe) Fe {
            var out = x;
            const underflow: bool = @bitCast(out.v.subWithOverflow(y.v));
            _ = out.v.conditionalAddWithOverflow(underflow, self.v);
            return out;
        }

        /// Converts a field element to the Montgomery form.
        pub fn toMontgomery(self: Self, x: *Fe) RepresentationError!void {
            if (x.montgomery) {
                return error.UnexpectedRepresentation;
            }
            self.shrink(x) catch unreachable;
            x.* = self.montgomeryMul(x.*, self.rr);
            x.montgomery = true;
        }

        /// Takes a field element out of the Montgomery form.
        pub fn fromMontgomery(self: Self, x: *Fe) RepresentationError!void {
            if (!x.montgomery) {
                return error.UnexpectedRepresentation;
            }
            self.shrink(x) catch unreachable;
            x.* = self.montgomeryMul(x.*, self.one());
            x.montgomery = false;
        }

        /// Reduces an arbitrary `Uint`, converting it to a field element in normal form.
        /// `x` may have any number of limbs, including fewer than the modulus.
        pub fn reduce(self: Self, x: anytype) Fe {
            var out = self.zero;
            const x_limbs = x.limbsConst();
            // Number of limbs of `x` (counted from the bottom) still to be consumed.
            var remaining = x_limbs.len;
            if (self.limbs_count() >= 2) {
                // Up to `limbs_count - 1` top limbs fit below the modulus as they are.
                // Place them directly instead of shifting them in one by one.
                const direct = @min(remaining, self.limbs_count() - 1);
                var j = direct;
                while (j > 0) {
                    j -= 1;
                    remaining -= 1;
                    out.v.limbs()[j] = x_limbs[remaining];
                }
            }
            // Shift in the rest, reducing modulo m every time.
            while (remaining > 0) {
                remaining -= 1;
                self.shiftIn(&out, x_limbs[remaining]);
            }
            return out;
        }

        // Montgomery loop: writes x * y * R^-1 into `d` (not fully reduced) and returns
        // the overflow bit of the top limb.
        fn montgomeryLoop(self: Self, d: *Fe, x: Fe, y: Fe) u1 {
            assert(d.limbs_count() == x.limbs_count());
            assert(d.limbs_count() == y.limbs_count());
            assert(d.limbs_count() == self.limbs_count());

            const a_limbs = x.v.limbsConst();
            const b_limbs = y.v.limbsConst();
            const d_limbs = d.v.limbs();
            const m_limbs = self.v.limbsConst();

            var overflow: u1 = 0;
            for (0..self.limbs_count()) |i| {
                var carry: Limb = 0;

                var wide = ct.mulWide(a_limbs[i], b_limbs[0]);
                var z_lo = @addWithOverflow(d_limbs[0], wide.lo);
                const f = @as(TLimb, @truncate(z_lo[0] *% self.m0inv));
                var z_hi = wide.hi +% z_lo[1];
                wide = ct.mulWide(f, m_limbs[0]);
                z_lo = @addWithOverflow(z_lo[0], wide.lo);
                z_hi +%= z_lo[1];
                z_hi +%= wide.hi;
                carry = (z_hi << 1) | (z_lo[0] >> t_bits);

                for (1..self.limbs_count()) |j| {
                    wide = ct.mulWide(a_limbs[i], b_limbs[j]);
                    z_lo = @addWithOverflow(d_limbs[j], wide.lo);
                    z_hi = wide.hi +% z_lo[1];
                    wide = ct.mulWide(f, m_limbs[j]);
                    z_lo = @addWithOverflow(z_lo[0], wide.lo);
                    z_hi +%= z_lo[1];
                    z_hi +%= wide.hi;
                    z_lo = @addWithOverflow(z_lo[0], carry);
                    z_hi +%= z_lo[1];
                    // `j` starts at 1, so the destination index is always valid.
                    d_limbs[j - 1] = @as(TLimb, @truncate(z_lo[0]));
                    carry = (z_hi << 1) | (z_lo[0] >> t_bits);
                }
                const z = overflow + carry;
                d_limbs[self.limbs_count() - 1] = @as(TLimb, @truncate(z));
                overflow = @as(u1, @truncate(z >> t_bits));
            }
            return overflow;
        }

        // Montgomery multiplication: returns x * y * R^-1 (mod m).
        fn montgomeryMul(self: Self, x: Fe, y: Fe) Fe {
            var d = self.zero;
            assert(x.limbs_count() == self.limbs_count());
            assert(y.limbs_count() == self.limbs_count());
            const overflow = self.montgomeryLoop(&d, x, y);
            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
            const need_sub = ct.eql(overflow, underflow);
            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
            d.montgomery = x.montgomery == y.montgomery;
            return d;
        }

        // Montgomery squaring: returns x^2 * R^-1 (mod m), flagged as Montgomery.
        fn montgomerySq(self: Self, x: Fe) Fe {
            var d = self.zero;
            assert(x.limbs_count() == self.limbs_count());
            const overflow = self.montgomeryLoop(&d, x, x);
            const underflow = 1 -% @intFromBool(ct.limbsCmpGeq(d.v, self.v));
            const need_sub = ct.eql(overflow, underflow);
            _ = d.v.conditionalSubWithOverflow(need_sub, self.v);
            d.montgomery = true;
            return d;
        }

        // Returns x^e (mod m) in normal form, with the exponent provided as a byte string.
        // `public` must be set to `false` if the exponent is secret.
        fn powWithEncodedExponentInternal(self: Self, x: Fe, e: []const u8, endian: Endian, comptime public: bool) NullExponentError!Fe {
            var acc: u8 = 0;
            for (e) |b| acc |= b;
            if (acc == 0) return error.NullExponent;

            var out = self.one();
            self.toMontgomery(&out) catch unreachable;

            // The square-and-multiply path branches on exponent bits. It must only ever be
            // taken for public exponents, hence the explicit grouping under `public`.
            if (public and (e.len < 3 or (e.len == 3 and e[if (endian == .big) 0 else 2] <= 0b1111))) {
                // Do not use a precomputation table for short, public exponents
                var x_m = x;
                if (x.montgomery == false) {
                    self.toMontgomery(&x_m) catch unreachable;
                }
                var s = switch (endian) {
                    .big => 0,
                    .little => e.len - 1,
                };
                while (true) {
                    const b = e[s];
                    var j: u3 = 7;
                    while (true) : (j -= 1) {
                        out = self.montgomerySq(out);
                        const k: u1 = @truncate(b >> j);
                        if (k != 0) {
                            const t = self.montgomeryMul(out, x_m);
                            @memcpy(out.v.limbs(), t.v.limbsConst());
                        }
                        if (j == 0) break;
                    }
                    switch (endian) {
                        .big => {
                            s += 1;
                            if (s == e.len) break;
                        },
                        .little => {
                            if (s == 0) break;
                            s -= 1;
                        },
                    }
                }
            } else {
                // Fixed 4-bit window with a table of x^1..x^15 (Montgomery form).
                var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
                if (x.montgomery == false) {
                    self.toMontgomery(&pc[0]) catch unreachable;
                }
                for (1..pc.len) |i| {
                    pc[i] = self.montgomeryMul(pc[i - 1], pc[0]);
                }
                var t0 = self.zero;
                var s = switch (endian) {
                    .big => 0,
                    .little => e.len - 1,
                };
                while (true) {
                    const b = e[s];
                    for ([_]u3{ 4, 0 }) |j| {
                        for (0..4) |_| {
                            out = self.montgomerySq(out);
                        }
                        const k = (b >> j) & 0b1111;
                        if (public or std.options.side_channels_mitigations == .none) {
                            if (k == 0) continue;
                            t0 = pc[k - 1];
                        } else {
                            // Scan the whole table so the access pattern does not depend on `k`.
                            for (pc, 0..) |t, i| {
                                t0.v.cmov(ct.eql(k, @as(u8, @truncate(i + 1))), t.v);
                            }
                        }
                        const t1 = self.montgomeryMul(out, t0);
                        if (public) {
                            @memcpy(out.v.limbs(), t1.v.limbsConst());
                        } else {
                            out.v.cmov(!ct.eql(k, 0), t1.v);
                        }
                    }
                    switch (endian) {
                        .big => {
                            s += 1;
                            if (s == e.len) break;
                        },
                        .little => {
                            if (s == 0) break;
                            s -= 1;
                        },
                    }
                }
            }
            self.fromMontgomery(&out) catch unreachable;
            return out;
        }

        /// Multiplies two field elements.
        /// Accepts any mix of representations; the product is returned in normal form.
        pub fn mul(self: Self, x: Fe, y: Fe) Fe {
            if (x.montgomery != y.montgomery) {
                return self.montgomeryMul(x, y);
            }
            var a_ = x;
            if (x.montgomery == false) {
                self.toMontgomery(&a_) catch unreachable;
            } else {
                self.fromMontgomery(&a_) catch unreachable;
            }
            return self.montgomeryMul(a_, y);
        }

        /// Squares a field element.
        /// Like `mul(x, x)`, the square is returned in normal form.
        pub fn sq(self: Self, x: Fe) Fe {
            var out = x;
            if (x.montgomery) {
                self.fromMontgomery(&out) catch unreachable;
            }
            // montgomerySq yields x^2 * R^-1. A Montgomery multiplication by R^2 (which
            // contributes R^2 * R^-1 = R) cancels the stray factor and leaves x^2.
            out = self.montgomeryMul(self.montgomerySq(out), self.rr);
            // The value is x^2 in normal form; flag it accordingly.
            out.montgomery = false;
            return out;
        }

        /// Returns x^e (mod m) in constant time.
        /// The exponent is read from the raw value stored in `e` (expected in normal form).
        pub fn pow(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
            var buf: [Fe.encoded_bytes]u8 = undefined;
            e.toBytes(&buf, native_endian) catch unreachable;
            return self.powWithEncodedExponent(x, &buf, native_endian);
        }

        /// Returns x^e (mod m), assuming that the exponent is public.
        /// The function remains constant time with respect to `x`.
        /// The exponent is read from the raw value stored in `e` (expected in normal form).
        pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
            const e_normalized = Fe{ .v = e.v.normalize() };
            const e_limbs = e_normalized.v.limbsConst();
            // Exact number of significant bits in the exponent.
            const top = e_limbs[e_limbs.len - 1];
            const e_bits = (e_limbs.len - 1) * t_bits + (@bitSizeOf(Limb) - @clz(top));
            var buf_: [Fe.encoded_bytes]u8 = undefined;
            const buf = buf_[0 .. math.divCeil(usize, e_limbs.len * t_bits, 8) catch unreachable];
            e_normalized.toBytes(buf, .little) catch unreachable;
            // Keep only the bytes that hold significant bits. A zero exponent yields an
            // empty slice, which is rejected with `error.NullExponent`.
            const e_len = math.divCeil(usize, e_bits, 8) catch unreachable;
            return self.powWithEncodedPublicExponent(x, buf[0..e_len], .little);
        }

        /// Returns x^e (mod m), with the exponent provided as a byte string.
        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
        /// doesn't have to be created if a serialized representation is already available.
        ///
        /// If the exponent is public, `powWithEncodedPublicExponent()` can be used instead for a slight speedup.
        pub fn powWithEncodedExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
            return self.powWithEncodedExponentInternal(x, e, endian, false);
        }

        /// Returns x^e (mod m), the exponent being public and provided as a byte string.
        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
        /// doesn't have to be created if a serialized representation is already available.
        ///
        /// If the exponent is secret, `powWithEncodedExponent` must be used instead.
        pub fn powWithEncodedPublicExponent(self: Self, x: Fe, e: []const u8, endian: Endian) NullExponentError!Fe {
            return self.powWithEncodedExponentInternal(x, e, endian, true);
        }
    };
}

test "Modulus: reduce short integers, sq representation, short secret exponents" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    const M = Modulus(256);
    const m = try M.fromPrimitive(u256, 3429938563481314093726330772853735541133072814650493833233);
    const x = try M.Fe.fromPrimitive(u256, m, 80169837251094269539116136208111827396136208141182357733);
    const y = try M.Fe.fromPrimitive(u256, m, 24620149608466364616251608466389896540098571);

    // Reducing an integer with fewer limbs than the modulus is the identity.
    const small = try Uint(64).fromPrimitive(u64, 12345);
    try testing.expectEqual(@as(u64, 12345), try m.reduce(small).toPrimitive(u64));

    // `sq` returns a normal-form value, so it composes with `mul` like `mul(x, x)` does.
    const x_sq = m.sq(x);
    try testing.expect(!x_sq.montgomery);
    try testing.expect(m.mul(x_sq, y).eql(m.mul(m.mul(x, x), y)));

    // Short secret exponents give the same result as the public path.
    const e = [_]u8{ 0x01, 0x00, 0x01 };
    const secret_result = try m.powWithEncodedExponent(x, &e, .big);
    const public_result = try m.powWithEncodedPublicExponent(x, &e, .big);
    try testing.expect(secret_result.eql(public_result));
}
818
// Selects the arithmetic helpers at compile time: branch-free versions unless
// side-channel mitigations are explicitly disabled.
const ct = switch (std.options.side_channels_mitigations) {
    .none => ct_unprotected,
    else => ct_protected,
};
820
// Branch-free helpers used when side-channel mitigations are enabled.
const ct_protected = struct {
    // Yields `x` when `on` is set and `y` otherwise, without branching.
    fn select(on: bool, x: Limb, y: Limb) Limb {
        // All ones when `on` is set, all zeros otherwise.
        const keep_x: Limb = @as(Limb, 0) -% @intFromBool(on);
        return (x & keep_x) | (y & ~keep_x);
    }

    // Equality without branching: two values are equal iff neither
    // subtraction borrows.
    fn eql(x: anytype, y: @TypeOf(x)) bool {
        const borrow_xy: u1 = @subWithOverflow(x, y)[1];
        const borrow_yx: u1 = @subWithOverflow(y, x)[1];
        const differ: u1 = borrow_xy | borrow_yx;
        return @bitCast(~differ);
    }

    // Returns true if the big integer `x` is below `y`. The answer is the final
    // borrow of a full limb-by-limb subtraction, so every limb is processed.
    fn limbsCmpLt(x: anytype, y: @TypeOf(x)) bool {
        const lhs = x.limbsConst();
        const rhs = y.limbsConst();
        var borrow: u1 = 0;
        for (lhs, rhs) |l, r| {
            const diff = l -% r -% borrow;
            borrow = @truncate(diff >> t_bits);
        }
        return borrow == 1;
    }

    // Returns true if the big integer `x` is greater than or equal to `y`.
    fn limbsCmpGeq(x: anytype, y: @TypeOf(x)) bool {
        const lt = limbsCmpLt(x, y);
        return !lt;
    }

    // Full double-width product of two limbs, assembled from half-limb partial
    // products (schoolbook multiplication on two digits).
    fn mulWide(x: Limb, y: Limb) WideLimb {
        const half = @typeInfo(Limb).int.bits / 2;
        const H = meta.Int(.unsigned, half);
        const x_lo: H = @truncate(x);
        const x_hi: H = @truncate(x >> half);
        const y_lo: H = @truncate(y);
        const y_hi: H = @truncate(y >> half);
        const low_product = math.mulWide(H, x_lo, y_lo);
        const middle = math.mulWide(H, x_hi, y_lo) + (low_product >> half);
        var cross: Limb = @as(H, @truncate(middle));
        const middle_hi: H = @truncate(middle >> half);
        cross += math.mulWide(H, x_lo, y_hi);
        return .{
            .hi = math.mulWide(H, x_hi, y_hi) + middle_hi + (cross >> half),
            .lo = x *% y,
        };
    }
};
867
// Plain helpers used when side-channel mitigations are disabled.
// These may branch on their inputs and are NOT constant-time.
const ct_unprotected = struct {
    // Returns `x` if `on` is true, otherwise `y`.
    fn select(on: bool, x: Limb, y: Limb) Limb {
        return switch (on) {
            true => x,
            false => y,
        };
    }

    // Returns true if both values are equal (ordinary comparison).
    fn eql(x: anytype, y: @TypeOf(x)) bool {
        const same = x == y;
        return same;
    }

    // Returns true if the big integer `x` is below `y`. Scans from the most
    // significant limb and stops at the first difference.
    fn limbsCmpLt(x: anytype, y: @TypeOf(x)) bool {
        const lhs = x.limbsConst();
        const rhs = y.limbsConst();
        assert(lhs.len == rhs.len);

        var idx = lhs.len;
        while (idx > 0) : (idx -= 1) {
            const l = lhs[idx - 1];
            const r = rhs[idx - 1];
            if (l != r) return l < r;
        }
        return false;
    }

    // Returns true if the big integer `x` is greater than or equal to `y`.
    fn limbsCmpGeq(x: anytype, y: @TypeOf(x)) bool {
        const lt = limbsCmpLt(x, y);
        return !lt;
    }

    // Full double-width product of two limbs, using the native wide multiplication.
    fn mulWide(x: Limb, y: Limb) WideLimb {
        const product = math.mulWide(Limb, x, y);
        const limb_bits = @typeInfo(Limb).int.bits;
        return .{
            .hi = @truncate(product >> limb_bits),
            .lo = @truncate(product),
        };
    }
};
909
test "finite field arithmetic" {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    const M = Modulus(256);
    const m = try M.fromPrimitive(u256, 3429938563481314093726330772853735541133072814650493833233);
    var x = try M.Fe.fromPrimitive(u256, m, 80169837251094269539116136208111827396136208141182357733);
    const y = try M.Fe.fromPrimitive(u256, m, 24620149608466364616251608466389896540098571);

    // Round trip through a primitive; narrow types must report overflow.
    const x_ = try x.toPrimitive(u256);
    try testing.expect((try M.Fe.fromPrimitive(@TypeOf(x_), m, x_)).eql(x));
    try testing.expectError(error.Overflow, x.toPrimitive(u50));

    try testing.expectEqual(@as(usize, 192), m.bits());

    // Multiplication gives the same product whatever the representation of `x`.
    const x_y_expected: u256 = 1666576607955767413750776202132407807424848069716933450241;
    var x_y = m.mul(x, y);
    try testing.expectEqual(x_y_expected, try x_y.toPrimitive(u256));

    try m.toMontgomery(&x);
    x_y = m.mul(x, y);
    try testing.expectEqual(x_y_expected, try x_y.toPrimitive(u256));
    try m.fromMontgomery(&x);

    // Addition followed by subtraction restores the original value.
    x = m.add(x, y);
    try testing.expectEqual(@as(u256, 80169837251118889688724602572728079004602598037722456304), try x.toPrimitive(u256));
    x = m.sub(x, y);
    try testing.expectEqual(@as(u256, 80169837251094269539116136208111827396136208141182357733), try x.toPrimitive(u256));

    // Reduction of a wider integer.
    const big = try Uint(512).fromPrimitive(u495, 77285373554113307281465049383342993856348131409372633077285373554113307281465049383323332333429938563481314093726330772853735541133072814650493833233);
    const reduced = m.reduce(big);
    try testing.expectEqual(@as(u495, 858047099884257670294681641776170038885500210968322054970), try reduced.toPrimitive(u495));

    // Exponentiation is independent of the input representation; zero exponents are rejected.
    const x_pow_y = try m.powPublic(x, y);
    try testing.expectEqual(@as(u256, 1631933139300737762906024873185789093007782131928298618473), try x_pow_y.toPrimitive(u256));
    try m.toMontgomery(&x);
    const x_pow_y2 = try m.powPublic(x, y);
    try m.fromMontgomery(&x);
    try testing.expect(x_pow_y2.eql(x_pow_y));
    try testing.expectError(error.NullExponent, m.powPublic(x, m.zero));

    try testing.expect(!x.isZero());
    try testing.expect(!y.isZero());
    try testing.expect(m.v.isOdd());

    // Squaring agrees with multiplication, for both representations of `x`.
    const x_sq = m.sq(x);
    const x_sq2 = m.mul(x, x);
    try testing.expect(x_sq.eql(x_sq2));
    try m.toMontgomery(&x);
    const x_sq3 = m.sq(x);
    const x_sq4 = m.mul(x, x);
    try testing.expect(x_sq.eql(x_sq3));
    try testing.expect(x_sq3.eql(x_sq4));
    try m.fromMontgomery(&x);
}
964
// Shared checks for a `ct` implementation (`ct_protected` or `ct_unprotected`).
fn testCt(ct_: anytype) !void {
    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;

    const l0: Limb = 0;
    const l1: Limb = 1;
    try testing.expectEqual(l1, ct_.select(true, l1, l0));
    try testing.expectEqual(l0, ct_.select(false, l1, l0));
    try testing.expectEqual(false, ct_.eql(l1, l0));
    try testing.expectEqual(true, ct_.eql(l1, l1));

    const M = Modulus(256);
    const m = try M.fromPrimitive(u256, 3429938563481314093726330772853735541133072814650493833233);
    const x = try M.Fe.fromPrimitive(u256, m, 80169837251094269539116136208111827396136208141182357733);
    const y = try M.Fe.fromPrimitive(u256, m, 24620149608466364616251608466389896540098571);
    try testing.expectEqual(false, ct_.limbsCmpLt(x.v, y.v));
    try testing.expectEqual(true, ct_.limbsCmpGeq(x.v, y.v));

    try testing.expectEqual(WideLimb{ .hi = 0, .lo = 0x88 }, ct_.mulWide(1 << 3, (1 << 4) + 1));
    // Full-width operands exercise every carry between the partial products:
    // (2^b - 1)^2 = (2^b - 2) * 2^b + 1.
    const limb_max = math.maxInt(Limb);
    try testing.expectEqual(WideLimb{ .hi = limb_max - 1, .lo = 1 }, ct_.mulWide(limb_max, limb_max));
}
984
test ct {
    // Run the shared checks against both implementations.
    inline for (.{ ct_protected, ct_unprotected }) |impl| {
        try testCt(impl);
    }
}