master
1const std = @import("std");
2const common = @import("../common.zig");
3const crypto = std.crypto;
4const debug = std.debug;
5const math = std.math;
6const mem = std.mem;
7
8const Field = common.Field;
9
10const NonCanonicalError = std.crypto.errors.NonCanonicalError;
11const NotSquareError = std.crypto.errors.NotSquareError;
12
/// Size in bytes of a serialized scalar (384 bits).
pub const encoded_length = 48;

/// A serialized scalar; APIs in this module expect the canonical encoding.
pub const CompressedScalar = [encoded_length]u8;
18
/// Field element type for integers modulo the P-384 group order L.
/// The arithmetic backend is `p384_scalar_64.zig` (presumably fiat-crypto
/// generated code, given the `.fiat` parameter name — confirm).
const Fe = Field(.{
    .fiat = @import("p384_scalar_64.zig"),
    .field_order = 39402006196394479212279040100143613805079739270465446667946905279627659399113263569398956308152294913554433653942643,
    .field_bits = 384,
    .saturated_bits = 384,
    .encoded_length = encoded_length,
});
26
/// The order L of the scalar field, re-exported from the underlying field type.
pub const field_order = Fe.field_order;
29
/// Fail with `error.NonCanonical` unless `s` is a canonical scalar encoding.
/// Succeeds (returns nothing) for canonical input.
pub fn rejectNonCanonical(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!void {
    try Fe.rejectNonCanonical(s, endian);
}
34
/// Reduce a 64-byte (512-bit) integer modulo L and return its canonical encoding.
/// The same `endian` is used for both the input and the output.
pub fn reduce64(s: [64]u8, endian: std.builtin.Endian) CompressedScalar {
    const reduced = Scalar.fromBytes64(s, endian);
    return reduced.toBytes(endian);
}
39
/// Multiply two encoded scalars: returns the encoding of a*b (mod L).
/// Fails with `error.NonCanonical` if either input is not canonical.
pub fn mul(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const lhs = try Scalar.fromBytes(a, endian);
    const rhs = try Scalar.fromBytes(b, endian);
    return lhs.mul(rhs).toBytes(endian);
}
44
/// Fused multiply-add on encoded scalars: returns the encoding of a*b+c (mod L).
/// Fails with `error.NonCanonical` if any input is not canonical.
pub fn mulAdd(a: CompressedScalar, b: CompressedScalar, c: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const factor_a = try Scalar.fromBytes(a, endian);
    const factor_b = try Scalar.fromBytes(b, endian);
    const addend = try Scalar.fromBytes(c, endian);
    const product = factor_a.mul(factor_b);
    return product.add(addend).toBytes(endian);
}
49
/// Add two encoded scalars: returns the encoding of a+b (mod L).
/// Fails with `error.NonCanonical` if either input is not canonical.
pub fn add(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const lhs = try Scalar.fromBytes(a, endian);
    const rhs = try Scalar.fromBytes(b, endian);
    return lhs.add(rhs).toBytes(endian);
}
54
/// Negate an encoded scalar: returns the encoding of -s (mod L).
/// Fails with `error.NonCanonical` if `s` is not canonical.
pub fn neg(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const value = try Scalar.fromBytes(s, endian);
    return value.neg().toBytes(endian);
}
59
/// Subtract encoded scalars: returns the encoding of a-b (mod L).
/// Fails with `error.NonCanonical` if either input is not canonical.
pub fn sub(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
    const minuend = try Scalar.fromBytes(a, endian);
    const subtrahend = try Scalar.fromBytes(b, endian);
    return minuend.sub(subtrahend).toBytes(endian);
}
64
/// Return the encoding of a freshly generated random (non-zero) scalar.
pub fn random(endian: std.builtin.Endian) CompressedScalar {
    const r = Scalar.random();
    return r.toBytes(endian);
}
69
/// A scalar modulo the group order L, in unpacked (field element) representation.
pub const Scalar = struct {
    fe: Fe,

    /// The scalar 0.
    pub const zero = Scalar{ .fe = Fe.zero };

    /// The scalar 1.
    pub const one = Scalar{ .fe = Fe.one };

    /// Unpack a serialized scalar.
    /// Fails with `error.NonCanonical` if `s` is not a canonical encoding.
    pub fn fromBytes(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!Scalar {
        return Scalar{ .fe = try Fe.fromBytes(s, endian) };
    }

    /// Reduce a 512-bit input modulo L.
    pub fn fromBytes64(s: [64]u8, endian: std.builtin.Endian) Scalar {
        const t = ScalarDouble.fromBytes(512, s, endian);
        return t.reduce(512);
    }

    /// Pack a scalar into its serialized form.
    pub fn toBytes(n: Scalar, endian: std.builtin.Endian) CompressedScalar {
        return n.fe.toBytes(endian);
    }

    /// Return true if the scalar is zero.
    pub fn isZero(n: Scalar) bool {
        return n.fe.isZero();
    }

    /// Return true if the scalar is odd.
    pub fn isOdd(n: Scalar) bool {
        return n.fe.isOdd();
    }

    /// Return true if a and b represent the same scalar.
    pub fn equivalent(a: Scalar, b: Scalar) bool {
        return a.fe.equivalent(b.fe);
    }

    /// Compute x+y (mod L).
    pub fn add(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.add(y.fe) };
    }

    /// Compute x-y (mod L).
    pub fn sub(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.sub(y.fe) };
    }

    /// Compute 2n (mod L).
    pub fn dbl(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.dbl() };
    }

    /// Compute x*y (mod L).
    pub fn mul(x: Scalar, y: Scalar) Scalar {
        return Scalar{ .fe = x.fe.mul(y.fe) };
    }

    /// Compute n^2 (mod L).
    pub fn sq(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.sq() };
    }

    /// Compute a^n (mod L), for an exponent `n` of type `T` known at compile time.
    pub fn pow(a: Scalar, comptime T: type, comptime n: T) Scalar {
        // The field-level `pow` takes the exponent type explicitly
        // (`comptime T: type, comptime n: T`), so `T` must be forwarded too;
        // passing only `n` fails to compile once this function is instantiated.
        // NOTE(review): signature taken from `Field.pow` in common.zig — confirm.
        return Scalar{ .fe = a.fe.pow(T, n) };
    }

    /// Compute -n (mod L).
    pub fn neg(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.neg() };
    }

    /// Compute n^-1 (mod L).
    /// NOTE(review): the result for n == 0 is whatever `Fe.invert` returns — confirm.
    pub fn invert(n: Scalar) Scalar {
        return Scalar{ .fe = n.fe.invert() };
    }

    /// Return true if n is a quadratic residue mod L.
    pub fn isSquare(n: Scalar) bool {
        return n.fe.isSquare();
    }

    /// Return a square root of n (mod L), or `error.NotSquare` if n has none.
    pub fn sqrt(n: Scalar) NotSquareError!Scalar {
        return Scalar{ .fe = try n.fe.sqrt() };
    }

    /// Return a random non-zero scalar < L.
    ///
    /// 64 random bytes are reduced modulo L. Because L is 384 bits wide, the
    /// 128 extra input bits keep the reduction bias negligible. A zero result
    /// is rejected and resampled.
    pub fn random() Scalar {
        var s: [64]u8 = undefined;
        while (true) {
            crypto.random.bytes(&s);
            const n = Scalar.fromBytes64(s, .little);
            if (!n.isZero()) {
                return n;
            }
        }
    }
};
173
/// Wide (up to 512-bit) integer split into two 256-bit limbs, each held as a
/// field element: value = x1 + x2 * 2^256.
const ScalarDouble = struct {
    x1: Fe,
    x2: Fe,

    /// Split a `bits`-bit integer, given as `bits / 8` bytes, into two limbs.
    /// Each limb is at most 256 bits and therefore already below L, so
    /// decoding it cannot fail.
    fn fromBytes(comptime bits: usize, s_: [bits / 8]u8, endian: std.builtin.Endian) ScalarDouble {
        debug.assert(bits > 0 and bits <= 512 and bits >= Fe.saturated_bits and bits <= Fe.saturated_bits * 2);

        // Normalize the input to little-endian byte order.
        var le = s_;
        if (endian == .big) mem.reverse(u8, &le);

        // Low limb: bytes [0, 32), zero-padded to a full field encoding.
        var low_bytes = [_]u8{0} ** encoded_length;
        const low_len = @min(le.len, 32);
        low_bytes[0..low_len].* = le[0..low_len].*;
        const low = Fe.fromBytes(low_bytes, .little) catch unreachable;

        // High limb: bytes [32, 64), zero-padded; stays zero if absent.
        var high = Fe.zero;
        if (s_.len >= 32) {
            var high_bytes = [_]u8{0} ** encoded_length;
            const high_len = @min(le.len - 32, 32);
            high_bytes[0..high_len].* = le[32..][0..high_len].*;
            high = Fe.fromBytes(high_bytes, .little) catch unreachable;
        }

        return ScalarDouble{ .x1 = low, .x2 = high };
    }

    /// Recombine the limbs modulo L: x1 + x2 * 2^256 (the high limb is only
    /// included when `bits >= 256`).
    fn reduce(expanded: ScalarDouble, comptime bits: usize) Scalar {
        debug.assert(bits > 0 and bits <= Fe.saturated_bits * 3 and bits <= 512);
        const fe = if (bits < 256) expanded.x1 else blk: {
            // Weight of the high limb; 2^256 < L, so this conversion cannot fail.
            const radix = Fe.fromInt(1 << 256) catch unreachable;
            break :blk expanded.x1.add(expanded.x2.mul(radix));
        };
        return Scalar{ .fe = fe };
    }
};