master
1const std = @import("std");
2const crypto = std.crypto;
3const debug = std.debug;
4const mem = std.mem;
5const meta = std.meta;
6
7const NonCanonicalError = crypto.errors.NonCanonicalError;
8const NotSquareError = crypto.errors.NotSquareError;
9
/// Parameters to create a finite field type.
pub const FieldParams = struct {
    /// Module providing the low-level field arithmetic consumed by `Field`,
    /// including its `MontgomeryDomainFieldElement` and
    /// `NonMontgomeryDomainFieldElement` limb types.
    fiat: type,
    /// The field modulus, i.e. the number of elements in the field.
    field_order: comptime_int,
    /// Number of bits needed to represent any element of the field.
    field_bits: comptime_int,
    /// Number of bits that can be saturated without overflowing.
    saturated_bits: comptime_int,
    /// Number of bytes in the serialized encoding of an element.
    encoded_length: comptime_int,
};
18
/// A field element, internally stored in Montgomery domain.
///
/// All arithmetic is delegated to `params.fiat`, which must expose the
/// `MontgomeryDomainFieldElement` / `NonMontgomeryDomainFieldElement` limb types and
/// the primitives called below (`setOne`, `fromBytes`, `toBytes`, `toMontgomery`,
/// `fromMontgomery`, `nonzero`, `selectznz`, `add`, `sub`, `mul`, `square`, `opp`,
/// `msat`, `divstep`, `divstepPrecomp`).
pub fn Field(comptime params: FieldParams) type {
    const fiat = params.fiat;
    const MontgomeryDomainFieldElement = fiat.MontgomeryDomainFieldElement;
    const NonMontgomeryDomainFieldElement = fiat.NonMontgomeryDomainFieldElement;

    return struct {
        const Fe = @This();

        limbs: MontgomeryDomainFieldElement,

        /// Field size.
        pub const field_order = params.field_order;

        /// Number of bits to represent the set of all elements.
        pub const field_bits = params.field_bits;

        /// Number of bits that can be saturated without overflowing.
        pub const saturated_bits = params.saturated_bits;

        /// Number of bytes required to encode an element.
        pub const encoded_length = params.encoded_length;

        /// Unsigned integer type exactly as wide as an encoded element.
        const EncodedInt = meta.Int(.unsigned, encoded_length * 8);

        // Moduli for which `isSquare` / `uncheckedSqrt` use dedicated addition chains.
        // NIST P-256 prime.
        const p256_field_order = 115792089210356248762697446949407573530086143415290314195533631308867097853951;
        // NIST P-384 prime.
        const p384_field_order = 39402006196394479212279040100143613805079739270465446667948293404245721771496870329047266088258938001861606973112319;
        // secp256k1 prime.
        const secp256k1_field_order = 115792089237316195423570985008687907853269984665640564039457584007908834671663;

        /// Zero.
        pub const zero: Fe = Fe{ .limbs = mem.zeroes(MontgomeryDomainFieldElement) };

        /// One, in Montgomery form (as produced by `fiat.setOne`).
        pub const one = one: {
            var fe: Fe = undefined;
            fiat.setOne(&fe.limbs);
            break :one fe;
        };

        /// Reject non-canonical encodings of an element.
        ///
        /// Returns `error.NonCanonical` unless the encoded integer is strictly less
        /// than `field_order`. The comparison uses `crypto.timing_safe.compare`.
        pub fn rejectNonCanonical(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!void {
            const s = if (endian == .little) s_ else orderSwap(s_);
            // Little-endian encoding of the modulus, computed once at compile time.
            const field_order_s = comptime fos: {
                var fos: [encoded_length]u8 = undefined;
                mem.writeInt(EncodedInt, &fos, field_order, .little);
                break :fos fos;
            };
            if (crypto.timing_safe.compare(u8, &s, &field_order_s, .little) != .lt) {
                return error.NonCanonical;
            }
        }

        /// Swap the endianness of an encoded element.
        pub fn orderSwap(s: [encoded_length]u8) [encoded_length]u8 {
            var t = s;
            mem.reverse(u8, &t);
            return t;
        }

        /// Unpack a field element.
        ///
        /// Returns `error.NonCanonical` if the encoding is not below `field_order`.
        pub fn fromBytes(s_: [encoded_length]u8, endian: std.builtin.Endian) NonCanonicalError!Fe {
            const s = if (endian == .little) s_ else orderSwap(s_);
            try rejectNonCanonical(s, .little);
            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
            fiat.fromBytes(&limbs_z, s);
            var limbs: MontgomeryDomainFieldElement = undefined;
            fiat.toMontgomery(&limbs, limbs_z);
            return Fe{ .limbs = limbs };
        }

        /// Pack a field element into its canonical encoding.
        pub fn toBytes(fe: Fe, endian: std.builtin.Endian) [encoded_length]u8 {
            var limbs_z: NonMontgomeryDomainFieldElement = undefined;
            fiat.fromMontgomery(&limbs_z, fe.limbs);
            var s: [encoded_length]u8 = undefined;
            fiat.toBytes(&s, limbs_z);
            return if (endian == .little) s else orderSwap(s);
        }

        /// Element as an integer.
        pub const IntRepr = meta.Int(.unsigned, params.field_bits);

        /// Create a field element from an integer.
        ///
        /// Returns `error.NonCanonical` if `x` is not less than `field_order`.
        pub fn fromInt(comptime x: IntRepr) NonCanonicalError!Fe {
            var s: [encoded_length]u8 = undefined;
            // Write through the full encoded width: `IntRepr` widens losslessly, and
            // `field_bits` then need not equal `encoded_length * 8`.
            mem.writeInt(EncodedInt, &s, x, .little);
            return fromBytes(s, .little);
        }

        /// Return the field element as an integer.
        pub fn toInt(fe: Fe) IntRepr {
            const s = fe.toBytes(.little);
            // The canonical encoding is below `field_order`, so it fits in `IntRepr`.
            return @intCast(mem.readInt(EncodedInt, &s, .little));
        }

        /// Return true if the field element is zero.
        pub fn isZero(fe: Fe) bool {
            var z: @TypeOf(fe.limbs[0]) = undefined;
            fiat.nonzero(&z, fe.limbs);
            return z == 0;
        }

        /// Return true if both field elements are equivalent.
        pub fn equivalent(a: Fe, b: Fe) bool {
            return a.sub(b).isZero();
        }

        /// Return true if the element, as a canonical integer, is odd.
        pub fn isOdd(fe: Fe) bool {
            const s = fe.toBytes(.little);
            return @as(u1, @truncate(s[0])) != 0;
        }

        /// Conditionally replace a field element with `a` if `c` is 1; leave it
        /// unchanged if `c` is 0.
        /// The selection is delegated to `fiat.selectznz` instead of branching on `c`.
        pub fn cMov(fe: *Fe, a: Fe, c: u1) void {
            fiat.selectznz(&fe.limbs, c, fe.limbs, a.limbs);
        }

        /// Add field elements.
        pub fn add(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.add(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Subtract field elements.
        pub fn sub(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.sub(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Double a field element.
        pub fn dbl(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.add(&fe.limbs, a.limbs, a.limbs);
            return fe;
        }

        /// Multiply field elements.
        pub fn mul(a: Fe, b: Fe) Fe {
            var fe: Fe = undefined;
            fiat.mul(&fe.limbs, a.limbs, b.limbs);
            return fe;
        }

        /// Square a field element.
        pub fn sq(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.square(&fe.limbs, a.limbs);
            return fe;
        }

        /// Square a field element n times, i.e. compute a^(2^n).
        fn sqn(a: Fe, comptime n: comptime_int) Fe {
            var fe = a;
            for (0..n) |_| fe = fe.sq();
            return fe;
        }

        /// Compute a^n by right-to-left square-and-multiply.
        ///
        /// The exponent is a compile-time constant, so the branch on each of its
        /// bits never depends on runtime data.
        pub fn pow(a: Fe, comptime T: type, comptime n: T) Fe {
            var fe = one;
            var x: T = n;
            var t = a;
            while (true) {
                if (@as(u1, @truncate(x)) != 0) fe = fe.mul(t);
                x >>= 1;
                if (x == 0) break;
                t = t.sq();
            }
            return fe;
        }

        /// Negate a field element.
        pub fn neg(a: Fe) Fe {
            var fe: Fe = undefined;
            fiat.opp(&fe.limbs, a.limbs);
            return fe;
        }

        /// Return the inverse of a field element, or 0 if a=0.
        // Field inversion from https://eprint.iacr.org/2021/549.pdf
        pub fn invert(a: Fe) Fe {
            // Fixed number of divsteps, derived only from `field_bits`, so the
            // amount of work does not depend on `a`.
            const iterations = (49 * field_bits + if (field_bits < 46) 80 else 57) / 17;
            const Limbs = @TypeOf(a.limbs);
            const Word = @TypeOf(a.limbs[0]);
            // f and g carry one extra limb for the signed intermediate values.
            const XLimbs = [a.limbs.len + 1]Word;

            var d: Word = 1;
            // Presumably the modulus in saturated form, as returned by `fiat.msat`.
            var f = comptime blk: {
                var f: XLimbs = undefined;
                fiat.msat(&f);
                break :blk f;
            };
            // g = a taken out of Montgomery form, with the extra top limb cleared.
            var g: XLimbs = undefined;
            fiat.fromMontgomery(g[0..a.limbs.len], a.limbs);
            g[g.len - 1] = 0;

            var r = Fe.one.limbs;
            var v = Fe.zero.limbs;

            var out1: Word = undefined;
            var out2: XLimbs = undefined;
            var out3: XLimbs = undefined;
            var out4: Limbs = undefined;
            var out5: Limbs = undefined;

            // Run divsteps in pairs, ping-ponging between the state variables and
            // the out* scratch variables to avoid extra copies.
            var i: usize = 0;
            while (i < iterations - iterations % 2) : (i += 2) {
                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
                fiat.divstep(&d, &f, &g, &v, &r, out1, out2, out3, out4, out5);
            }
            if (iterations % 2 != 0) {
                fiat.divstep(&out1, &out2, &out3, &out4, &out5, d, f, g, v, r);
                v = out4;
                f = out2;
            }
            // v holds the result up to sign: negate it when f ended up negative
            // (its top bit is set).
            var v_opp: Limbs = undefined;
            fiat.opp(&v_opp, v);
            fiat.selectznz(&v, @as(u1, @truncate(f[f.len - 1] >> (@bitSizeOf(Word) - 1))), v, v_opp);

            // Final multiplication by the constant from `fiat.divstepPrecomp`,
            // presumably undoing the scaling accumulated by the divsteps.
            const precomp = blk: {
                var precomp: Limbs = undefined;
                fiat.divstepPrecomp(&precomp);
                break :blk precomp;
            };
            var fe: Fe = undefined;
            fiat.mul(&fe.limbs, v, precomp);
            return fe;
        }

        /// Return true if the field element is a square.
        ///
        /// This evaluates Euler's criterion x2^((field_order-1)/2) == 1, so zero
        /// (whose Legendre symbol is 0, not 1) is reported as a non-square.
        pub fn isSquare(x2: Fe) bool {
            if (field_order == p256_field_order) {
                // Addition chain for the P-256 exponent (p-1)/2.
                const t110 = x2.mul(x2.sq()).sq();
                const t111 = x2.mul(t110);
                const t111111 = t111.mul(x2.mul(t110).sqn(3));
                const x15 = t111111.sqn(6).mul(t111111).sqn(3).mul(t111);
                const x16 = x15.sq().mul(x2);
                const x53 = x16.sqn(16).mul(x16).sqn(15);
                const x47 = x15.mul(x53);
                const ls = x47.mul(((x53.sqn(17).mul(x2)).sqn(143).mul(x47)).sqn(47)).sq().mul(x2);
                return ls.equivalent(Fe.one);
            } else if (field_order == p384_field_order) {
                // Addition chain for the P-384 exponent (p-1)/2.
                const t111 = x2.mul(x2.mul(x2.sq()).sq());
                const t111111 = t111.mul(t111.sqn(3));
                const t1111110 = t111111.sq();
                const t1111111 = x2.mul(t1111110);
                const x12 = t1111110.sqn(5).mul(t111111);
                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
                const x32 = x31.sq().mul(x2);
                const x63 = x32.sqn(31).mul(x31);
                const x126 = x63.sqn(63).mul(x63);
                const ls = x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(95).mul(x31);
                return ls.equivalent(Fe.one);
            } else {
                const ls = x2.pow(std.meta.Int(.unsigned, field_bits), (field_order - 1) / 2); // Legendre symbol
                return ls.equivalent(Fe.one);
            }
        }

        // x=x2^((field_order+1)/4) w/ field order=3 (mod 4).
        // The result is only a square root when x2 actually is a square; `sqrt` checks this.
        fn uncheckedSqrt(x2: Fe) Fe {
            if (field_order % 4 != 3) @compileError("unimplemented");
            if (field_order == p256_field_order) {
                const t11 = x2.mul(x2.sq());
                const t1111 = t11.mul(t11.sqn(2));
                const t11111111 = t1111.mul(t1111.sqn(4));
                const x16 = t11111111.sqn(8).mul(t11111111);
                return x16.sqn(16).mul(x16).sqn(32).mul(x2).sqn(96).mul(x2).sqn(94);
            } else if (field_order == p384_field_order) {
                const t111 = x2.mul(x2.mul(x2.sq()).sq());
                const t111111 = t111.mul(t111.sqn(3));
                const t1111110 = t111111.sq();
                const t1111111 = x2.mul(t1111110);
                const x12 = t1111110.sqn(5).mul(t111111);
                const x31 = x12.sqn(12).mul(x12).sqn(7).mul(t1111111);
                const x32 = x31.sq().mul(x2);
                const x63 = x32.sqn(31).mul(x31);
                const x126 = x63.sqn(63).mul(x63);
                return x126.sqn(126).mul(x126).sqn(3).mul(t111).sqn(33).mul(x32).sqn(64).mul(x2).sqn(30);
            } else if (field_order == secp256k1_field_order) {
                const t11 = x2.mul(x2.sq());
                const t1111 = t11.mul(t11.sqn(2));
                const t11111 = x2.mul(t1111.sq());
                const t1111111 = t11.mul(t11111.sqn(2));
                const x11 = t1111111.sqn(4).mul(t1111);
                const x22 = x11.sqn(11).mul(x11);
                const x27 = x22.sqn(5).mul(t11111);
                const x54 = x27.sqn(27).mul(x27);
                const x108 = x54.sqn(54).mul(x54);
                return x108.sqn(108).mul(x108).sqn(7).mul(t1111111).sqn(23).mul(x22).sqn(6).mul(t11).sqn(2);
            } else {
                return x2.pow(std.meta.Int(.unsigned, field_bits), (field_order + 1) / 4);
            }
        }

        /// Compute the square root of `x2`, returning `error.NotSquare` if `x2` was not a square.
        pub fn sqrt(x2: Fe) NotSquareError!Fe {
            const x = x2.uncheckedSqrt();
            if (x.sq().equivalent(x2)) {
                return x;
            }
            return error.NotSquare;
        }
    };
}
323}