master
1const std = @import("std");
2const common = @import("../common.zig");
3const crypto = std.crypto;
4const debug = std.debug;
5const math = std.math;
6const mem = std.mem;
7
8const Field = common.Field;
9
10const NonCanonicalError = std.crypto.errors.NonCanonicalError;
11const NotSquareError = std.crypto.errors.NotSquareError;
12
/// Size, in bytes, of a serialized scalar (256 bits).
pub const encoded_length = 256 / 8;

/// A serialized scalar of `encoded_length` bytes, holding a value in canonical form.
pub const CompressedScalar = [encoded_length]u8;
18
// Arithmetic modulo the scalar field order, instantiated from the generic
// `Field` with the implementation in `secp256k1_scalar_64.zig`.
const Fe = Field(.{
    .fiat = @import("secp256k1_scalar_64.zig"),
    // = 115792089237316195423570985008687907852837564279074904382605163141518161494337
    .field_order = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141,
    .field_bits = 256,
    .saturated_bits = 256,
    .encoded_length = encoded_length,
});

/// The order L of the scalar field; all scalar arithmetic is done modulo this value.
pub const field_order = Fe.field_order;
29
/// Check that `s` is a canonical scalar encoding in the given byte order,
/// returning a `NonCanonicalError` when it is not.
pub fn rejectNonCanonical(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!void {
    try Fe.rejectNonCanonical(s, endian);
}
34
/// Reduce a 48-byte (384-bit) integer modulo L and serialize the result
/// using the same byte order as the input.
pub fn reduce48(s: [48]u8, endian: std.builtin.Endian) CompressedScalar {
    const reduced = Scalar.fromBytes48(s, endian);
    return reduced.toBytes(endian);
}
39
/// Reduce a 64-byte (512-bit) integer modulo L and serialize the result
/// using the same byte order as the input.
pub fn reduce64(s: [64]u8, endian: std.builtin.Endian) CompressedScalar {
    const reduced = Scalar.fromBytes64(s, endian);
    return reduced.toBytes(endian);
}
44
/// Return the encoding of a*b (mod L).
/// Fails with a `NonCanonicalError` if either operand is not a canonical encoding.
pub fn mul(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const x = try Scalar.fromBytes(a, endian);
    const y = try Scalar.fromBytes(b, endian);
    const product = x.mul(y);
    return product.toBytes(endian);
}
49
/// Return the encoding of a*b+c (mod L).
/// Fails with a `NonCanonicalError` if any operand is not a canonical encoding.
pub fn mulAdd(a: CompressedScalar, b: CompressedScalar, c: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const x = try Scalar.fromBytes(a, endian);
    const y = try Scalar.fromBytes(b, endian);
    const product = x.mul(y);
    const z = try Scalar.fromBytes(c, endian);
    return product.add(z).toBytes(endian);
}
54
/// Return the encoding of a+b (mod L).
/// Fails with a `NonCanonicalError` if either operand is not a canonical encoding.
pub fn add(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const x = try Scalar.fromBytes(a, endian);
    const y = try Scalar.fromBytes(b, endian);
    const sum = x.add(y);
    return sum.toBytes(endian);
}
59
/// Return the encoding of -s (mod L).
/// Fails with a `NonCanonicalError` if `s` is not a canonical encoding.
pub fn neg(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const x = try Scalar.fromBytes(s, endian);
    const negated = x.neg();
    return negated.toBytes(endian);
}
64
/// Return the encoding of a-b (mod L).
/// Fails with a `NonCanonicalError` if either operand is not a canonical encoding.
pub fn sub(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const x = try Scalar.fromBytes(a, endian);
    const y = try Scalar.fromBytes(b, endian);
    const difference = x.sub(y);
    return difference.toBytes(endian);
}
69
/// Return the encoding, in the given byte order, of a freshly generated
/// random non-zero scalar.
pub fn random(endian: std.builtin.Endian) CompressedScalar {
    const r = Scalar.random();
    return r.toBytes(endian);
}
74
/// A scalar modulo L (the scalar field order), kept in unpacked field-element form.
pub const Scalar = struct {
    fe: Fe,

    /// The scalar 0.
    pub const zero = Scalar{ .fe = Fe.zero };

    /// The scalar 1.
    pub const one = Scalar{ .fe = Fe.one };

    /// Unpack a serialized scalar.
    /// Returns a `NonCanonicalError` if `s` is not a canonical encoding.
    pub fn fromBytes(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!Scalar {
        return Scalar{ .fe = try Fe.fromBytes(s, endian) };
    }

    /// Reduce a 384-bit (48-byte) integer modulo L.
    pub fn fromBytes48(s: [48]u8, endian: std.builtin.Endian) Scalar {
        const t = ScalarDouble.fromBytes(384, s, endian);
        return t.reduce(384);
    }

    /// Reduce a 512-bit (64-byte) integer modulo L.
    pub fn fromBytes64(s: [64]u8, endian: std.builtin.Endian) Scalar {
        const t = ScalarDouble.fromBytes(512, s, endian);
        return t.reduce(512);
    }

    /// Pack a scalar into `encoded_length` bytes using the given byte order.
    pub fn toBytes(n: Scalar, endian: std.builtin.Endian) CompressedScalar {
        return n.fe.toBytes(endian);
    }

    /// Return true if the scalar is zero.
    pub fn isZero(n: Scalar) bool {
        return n.fe.isZero();
    }

    /// Return true if the scalar is odd.
    pub fn isOdd(n: Scalar) bool {
        return n.fe.isOdd();
    }

    /// Return true if a and b represent the same value.
    pub fn equivalent(a: Scalar, b: Scalar) bool {
        return a.fe.equivalent(b.fe);
    }

    /// Compute x+y (mod L).
    pub fn add(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.add(y.fe) };
    }

    /// Compute x-y (mod L).
    pub fn sub(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.sub(y.fe) };
    }

    /// Compute 2n (mod L).
    pub fn dbl(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.dbl() };
    }

    /// Compute x*y (mod L).
    pub fn mul(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.mul(y.fe) };
    }

    /// Compute n^2 (mod L).
    pub fn sq(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.sq() };
    }

    /// Compute a^n (mod L), where the exponent `n` is a comptime value of type `T`.
    pub fn pow(a: Scalar, comptime T: type, comptime n: T) Scalar {
        // The field-element `pow` takes the exponent type and the exponent as two
        // separate comptime arguments (mirroring this signature). Forwarding only
        // `n` is a compile error as soon as this function is instantiated.
        return Scalar{ .fe = a.fe.pow(T, n) };
    }

    /// Compute -n (mod L).
    pub fn neg(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.neg() };
    }

    /// Compute n^-1 (mod L).
    pub fn invert(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.invert() };
    }

    /// Return true if n is a quadratic residue mod L.
    pub fn isSquare(n: Scalar) bool {
        return n.fe.isSquare();
    }

    /// Return a square root of n (mod L), or `error.NotSquare` if n has none.
    pub fn sqrt(n: Scalar) NotSquareError!Scalar {
        return Scalar{ .fe = try n.fe.sqrt() };
    }

    /// Return a random non-zero scalar < L.
    /// Draws 48 random bytes (wider than the 32-byte scalar, presumably to keep
    /// the modulo bias negligible), reduces them modulo L, and retries on zero.
    pub fn random() Scalar {
        var s: [48]u8 = undefined;
        while (true) {
            crypto.random.bytes(&s);
            const n = Scalar.fromBytes48(s, .little);
            if (!n.isZero()) {
                return n;
            }
        }
    }
};
184
/// A wide (up to 512-bit) integer awaiting reduction modulo L, split into three
/// little-endian 192-bit limbs: value = x1 + x2 * 2^192 + x3 * 2^384.
const ScalarDouble = struct {
    x1: Fe,
    x2: Fe,
    x3: Fe,

    /// Number of input bytes stored in each limb.
    const limb_bytes = 24;
    /// Number of input bits stored in each limb.
    const limb_bits = limb_bytes * 8;

    /// Split a `bits`-wide integer into limbs. Big-endian input is byte-reversed
    /// first so that limbs are always read little-endian.
    fn fromBytes(comptime bits: usize, input: [bits / 8]u8, endian: std.builtin.Endian) ScalarDouble {
        // `bits` is comptime-known, so reject an unsupported width at compile
        // time rather than through a runtime assertion.
        comptime debug.assert(bits > 0 and bits <= 512 and bits >= Fe.saturated_bits and bits <= Fe.saturated_bits * 3);

        var s = input;
        if (endian == .big) mem.reverse(u8, &s);

        return ScalarDouble{
            .x1 = loadLimb(&s, 0),
            .x2 = loadLimb(&s, limb_bytes),
            .x3 = loadLimb(&s, 2 * limb_bytes),
        };
    }

    /// Decode up to `limb_bytes` little-endian bytes of `bytes` starting at `offset`.
    /// Bytes past the end of `bytes` are treated as zero.
    fn loadLimb(bytes: []const u8, offset: usize) Fe {
        var b = [_]u8{0} ** encoded_length;
        if (offset < bytes.len) {
            const chunk = bytes[offset..][0..@min(bytes.len - offset, limb_bytes)];
            @memcpy(b[0..chunk.len], chunk);
        }
        // Any value below 2^192 is smaller than the field order, so the
        // zero-padded encoding is always canonical.
        return Fe.fromBytes(b, .little) catch unreachable;
    }

    /// Combine the limbs modulo L: x1 + x2 * 2^192 + x3 * 2^384.
    fn reduce(expanded: ScalarDouble, comptime bits: usize) Scalar {
        comptime debug.assert(bits > 0 and bits <= Fe.saturated_bits * 3 and bits <= 512);
        var fe = expanded.x1;
        if (bits >= limb_bits) {
            const st1 = Fe.fromInt(1 << limb_bits) catch unreachable;
            fe = fe.add(expanded.x2.mul(st1));
            if (bits >= 2 * limb_bits) {
                const st2 = st1.sq();
                fe = fe.add(expanded.x3.mul(st2));
            }
        }
        return Scalar{ .fe = fe };
    }
};