master
  1const std = @import("std");
  2const crypto = std.crypto;
  3const debug = std.debug;
  4const mem = std.mem;
  5const meta = std.meta;
  6
  7const NonCanonicalError = crypto.errors.NonCanonicalError;
  8const NotSquareError = crypto.errors.NotSquareError;
  9
 10/// Parameters to create a finite field type.
 11pub const FieldParams = struct {
 12    fiat: type,
 13    field_order: comptime_int,
 14    field_bits: comptime_int,
 15    saturated_bits: comptime_int,
 16    encoded_length: comptime_int,
 17};
 18
 19/// A field element, internally stored in Montgomery domain.
 20pub fn Field(comptime params: FieldParams) type {
 21    const fiat = params.fiat;
 22    const MontgomeryDomainFieldElement = fiat.MontgomeryDomainFieldElement;
 23    const NonMontgomeryDomainFieldElement = fiat.NonMontgomeryDomainFieldElement;
 24
 25    return struct {
 26        const Fe = @This();
 27
 28        limbs: MontgomeryDomainFieldElement,
 29
 30        /// Field size.
 31        pub const field_order = params.field_order;
 32
 33        /// Number of bits to represent the set of all elements.
 34        pub const field_bits = params.field_bits;
 35
 36        /// Number of bits that can be saturated without overflowing.
 37        pub const saturated_bits = params.saturated_bits;
 38
 39        /// Number of bytes required to encode an element.
 40        pub const encoded_length = params.encoded_length;
 41
 42        /// Zero.
 43        pub const zero: Fe = Fe{ .limbs = mem.zeroes(MontgomeryDomainFieldElement) };
 44
 45        /// One.
 46        pub const one = one: {
 47            var fe: Fe = undefined;
 48            fiat.setOne(&fe.limbs);
 49            break :one fe;
 50        };
 51
 52        /// Reject non-canonical encodings of an element.
 53        pub fn rejectNonCanonical(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!void {
 54            var s = if (endian == .little) s_ else orderSwap(s_);
 55            const field_order_s = comptime fos: {
 56                var fos: [encoded_length]u8 = undefined;
 57                mem.writeInt(std.meta.Int(.unsigned, encoded_length * 8), &fos, field_order, .little);
 58                break :fos fos;
 59            };
 60            if (crypto.timing_safe.compare(u8, &s, &field_order_s, .little) != .lt) {
 61                return error.NonCanonical;
 62            }
 63        }
 64
 65        /// Swap the endianness of an encoded element.
 66        pub fn orderSwap(s: [encoded_length]u8) [encoded_length]u8 {
 67            var t = s;
 68            for (s, 0..) |x, i| t[t.len - 1 - i] = x;
 69            return t;
 70        }
 71
 72        /// Unpack a field element.
 73        pub fn fromBytes(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!Fe {
 74            const s = if (endian == .little) s_ else orderSwap(s_);
 75            try rejectNonCanonical(s, .little);
 76            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
 77            fiat.fromBytes(&limbs_z, s);
 78            var limbs: MontgomeryDomainFieldElement = undefined;
 79            fiat.toMontgomery(&limbs, limbs_z);
 80            return Fe{ .limbs = limbs };
 81        }
 82
 83        /// Pack a field element.
 84        pub fn toBytes(fe: Fe, endian: std.builtin.Endian) [encoded_length]u8 {
 85            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
 86            fiat.fromMontgomery(&limbs_z, fe.limbs);
 87            var s: [encoded_length]u8 = undefined;
 88            fiat.toBytes(&s, limbs_z);
 89            return if (endian == .little) s else orderSwap(s);
 90        }
 91
 92        /// Element as an integer.
 93        pub const IntRepr = meta.Int(.unsigned, params.field_bits);
 94
 95        /// Create a field element from an integer.
 96        pub fn fromInt(comptime x: IntRepr) NonCanonicalError!Fe {
 97            var s: [encoded_length]u8 = undefined;
 98            mem.writeInt(IntRepr, &s, x, .little);
 99            return fromBytes(s, .little);
100        }
101
102        /// Return the field element as an integer.
103        pub fn toInt(fe: Fe) IntRepr {
104            const s = fe.toBytes(.little);
105            return mem.readInt(IntRepr, &s, .little);
106        }
107
108        /// Return true if the field element is zero.
109        pub fn isZero(fe: Fe) bool {
110            var z: @TypeOf(fe.limbs[0]) = undefined;
111            fiat.nonzero(&z, fe.limbs);
112            return z == 0;
113        }
114
115        /// Return true if both field elements are equivalent.
116        pub fn equivalent(a: Fe, b: Fe) bool {
117            return a.sub(b).isZero();
118        }
119
120        /// Return true if the element is odd.
121        pub fn isOdd(fe: Fe) bool {
122            const s = fe.toBytes(.little);
123            return @as(u1, @truncate(s[0])) != 0;
124        }
125
126        /// Conditonally replace a field element with `a` if `c` is positive.
127        pub fn cMov(fe: *Fe, a: Fe, c: u1) void {
128            fiat.selectznz(&fe.limbs, c, fe.limbs, a.limbs);
129        }
130
131        /// Add field elements.
132        pub fn add(a: Fe, b: Fe) Fe {
133            var fe: Fe = undefined;
134            fiat.add(&fe.limbs, a.limbs, b.limbs);
135            return fe;
136        }
137
138        /// Subtract field elements.
139        pub fn sub(a: Fe, b: Fe) Fe {
140            var fe: Fe = undefined;
141            fiat.sub(&fe.limbs, a.limbs, b.limbs);
142            return fe;
143        }
144
145        /// Double a field element.
146        pub fn dbl(a: Fe) Fe {
147            var fe: Fe = undefined;
148            fiat.add(&fe.limbs, a.limbs, a.limbs);
149            return fe;
150        }
151
152        /// Multiply field elements.
153        pub fn mul(a: Fe, b: Fe) Fe {
154            var fe: Fe = undefined;
155            fiat.mul(&fe.limbs, a.limbs, b.limbs);
156            return fe;
157        }
158
159        /// Square a field element.
160        pub fn sq(a: Fe) Fe {
161            var fe: Fe = undefined;
162            fiat.square(&fe.limbs, a.limbs);
163            return fe;
164        }
165
166        /// Square a field element n times.
167        fn sqn(a: Fe, comptime n: comptime_int) Fe {
168            var i: usize = 0;
169            var fe = a;
170            while (i < n) : (i += 1) {
171                fe = fe.sq();
172            }
173            return fe;
174        }
175
176        /// Compute a^n.
177        pub fn pow(a: Fe, comptime T: type, comptime n: T) Fe {
178            var fe = one;
179            var x: T = n;
180            var t = a;
181            while (true) {
182                if (@as(u1, @truncate(x)) != 0) fe = fe.mul(t);
183                x >>= 1;
184                if (x == 0) break;
185                t = t.sq();
186            }
187            return fe;
188        }
189
190        /// Negate a field element.
191        pub fn neg(a: Fe) Fe {
192            var fe: Fe = undefined;
193            fiat.opp(&fe.limbs, a.limbs);
194            return fe;
195        }
196
197        /// Return the inverse of a field element, or 0 if a=0.
198        // Field inversion from https://eprint.iacr.org/2021/549.pdf
199        pub fn invert(a: Fe) Fe {
200            const iterations = (49 * field_bits + if (field_bits < 46) 80 else 57) / 17;
201            const Limbs = @TypeOf(a.limbs);
202            const Word = @TypeOf(a.limbs[0]);
203            const XLimbs = [a.limbs.len + 1]Word;
204
205            var d: Word = 1;
206            var f = comptime blk: {
207                var f: XLimbs = undefined;
208                fiat.msat(&f);
209                break :blk f;
210            };
211            var g: XLimbs = undefined;
212            fiat.fromMontgomery(g[0..a.limbs.len], a.limbs);
213            g[g.len - 1] = 0;
214
215            var r = Fe.one.limbs;
216            var v = Fe.zero.limbs;
217
218            var out1: Word = undefined;
219            var out2: XLimbs = undefined;
220            var out3: XLimbs = undefined;
221            var out4: Limbs = undefined;
222            var out5: Limbs = undefined;
223
224            var i: usize = 0;
225            while (i < iterations - iterations % 2) : (i += 2) {
226                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
227                fiat.divstep(&d, &f, &g, &v, &r, out1, out2, out3, out4, out5);
228            }
229            if (iterations % 2 != 0) {
230                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
231                v = out4;
232                f = out2;
233            }
234            var v_opp: Limbs = undefined;
235            fiat.opp(&v_opp, v);
236            fiat.selectznz(&v, @as(u1, @truncate(f[f.len - 1] >> (@bitSizeOf(Word) - 1))), v, v_opp);
237
238            const precomp = blk: {
239                var precomp: Limbs = undefined;
240                fiat.divstepPrecomp(&precomp);
241                break :blk precomp;
242            };
243            var fe: Fe = undefined;
244            fiat.mul(&fe.limbs, v, precomp);
245            return fe;
246        }
247
248        /// Return true if the field element is a square.
249        pub fn isSquare(x2: Fe) bool {
250            if (field_order == 115792089210356248762697446949407573530086143415290314195533631308867097853951) {
251                const t110 = x2.mul(x2.sq()).sq();
252                const t111 = x2.mul(t110);
253                const t111111 = t111.mul(x2.mul(t110).sqn(3));
254                const x15 = t111111.sqn(6).mul(t111111).sqn(3).mul(t111);
255                const x16 = x15.sq().mul(x2);
256                const x53 = x16.sqn(16).mul(x16).sqn(15);
257                const x47 = x15.mul(x53);
258                const ls = x47.mul(((x53.sqn(17).mul(x2)).sqn(143).mul(x47)).sqn(47)).sq().mul(x2);
259                return ls.equivalent(Fe.one);
260            } else if (field_order == 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319) {
261                const t111 = x2.mul(x2.mul(x2.sq()).sq());
262                const t111111 = t111.mul(t111.sqn(3));
263                const t1111110 = t111111.sq();
264                const t1111111 = x2.mul(t1111110);
265                const x12 = t1111110.sqn(5).mul(t111111);
266                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
267                const x32 = x31.sq().mul(x2);
268                const x63 = x32.sqn(31).mul(x31);
269                const x126 = x63.sqn(63).mul(x63);
270                const ls = x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(95).mul(x31);
271                return ls.equivalent(Fe.one);
272            } else {
273                const ls = x2.pow(std.meta.Int(.unsigned, field_bits), (field_order - 1) / 2); // Legendre symbol
274                return ls.equivalent(Fe.one);
275            }
276        }
277
278        // x=x2^((field_order+1)/4) w/ field order=3 (mod 4).
279        fn uncheckedSqrt(x2: Fe) Fe {
280            if (field_order % 4 != 3) @compileError("unimplemented");
281            if (field_order == 115792089210356248762697446949407573530086143415290314195533631308867097853951) {
282                const t11 = x2.mul(x2.sq());
283                const t1111 = t11.mul(t11.sqn(2));
284                const t11111111 = t1111.mul(t1111.sqn(4));
285                const x16 = t11111111.sqn(8).mul(t11111111);
286                return x16.sqn(16).mul(x16).sqn(32).mul(x2).sqn(96).mul(x2).sqn(94);
287            } else if (field_order == 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319) {
288                const t111 = x2.mul(x2.mul(x2.sq()).sq());
289                const t111111 = t111.mul(t111.sqn(3));
290                const t1111110 = t111111.sq();
291                const t1111111 = x2.mul(t1111110);
292                const x12 = t1111110.sqn(5).mul(t111111);
293                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
294                const x32 = x31.sq().mul(x2);
295                const x63 = x32.sqn(31).mul(x31);
296                const x126 = x63.sqn(63).mul(x63);
297                return x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(64).mul(x2).sqn(30);
298            } else if (field_order == 115792089237316195423570985008687907853269984665640564039457584007908834671663) {
299                const t11 = x2.mul(x2.sq());
300                const t1111 = t11.mul(t11.sqn(2));
301                const t11111 = x2.mul(t1111.sq());
302                const t1111111 = t11.mul(t11111.sqn(2));
303                const x11 = t1111111.sqn(4).mul(t1111);
304                const x22 = x11.sqn(11).mul(x11);
305                const x27 = x22.sqn(5).mul(t11111);
306                const x54 = x27.sqn(27).mul(x27);
307                const x108 = x54.sqn(54).mul(x54);
308                return x108.sqn(108).mul(x108).sqn(7).mul(t1111111).sqn(23).mul(x22).sqn(6).mul(t11).sqn(2);
309            } else {
310                return x2.pow(std.meta.Int(.unsigned, field_bits), (field_order + 1) / 4);
311            }
312        }
313
314        /// Compute the square root of `x2`, returning `error.NotSquare` if `x2` was not a square.
315        pub fn sqrt(x2: Fe) NotSquareError!Fe {
316            const x = x2.uncheckedSqrt();
317            if (x.sq().equivalent(x2)) {
318                return x;
319            }
320            return error.NotSquare;
321        }
322    };
323}