master
  1const std = @import("std");
  2const common = @import("../common.zig");
  3const crypto = std.crypto;
  4const debug = std.debug;
  5const math = std.math;
  6const mem = std.mem;
  7
  8const Field = common.Field;
  9
 10const NonCanonicalError = std.crypto.errors.NonCanonicalError;
 11const NotSquareError = std.crypto.errors.NotSquareError;
 12
/// Number of bytes required to encode a scalar.
pub const encoded_length = 32;

/// A compressed scalar, in canonical form.
pub const CompressedScalar = [encoded_length]u8;

// Field element type used for all scalar arithmetic in this file.
// It is produced by the generic `Field` constructor from common.zig,
// parameterized with the implementation imported from
// secp256k1_scalar_64.zig, the scalar field order, and the encoding sizes.
const Fe = Field(.{
    .fiat = @import("secp256k1_scalar_64.zig"),
    .field_order = 115792089237316195423570985008687907852837564279074904382605163141518161494337,
    .field_bits = 256,
    .saturated_bits = 256,
    .encoded_length = encoded_length,
});

/// The scalar field order.
pub const field_order = Fe.field_order;
 29
 30/// Reject a scalar whose encoding is not canonical.
 31pub fn rejectNonCanonical(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!void {
 32    return Fe.rejectNonCanonical(s, endian);
 33}
 34
 35/// Reduce a 48-bytes scalar to the field size.
 36pub fn reduce48(s: [48]u8, endian: std.builtin.Endian) CompressedScalar {
 37    return Scalar.fromBytes48(s, endian).toBytes(endian);
 38}
 39
 40/// Reduce a 64-bytes scalar to the field size.
 41pub fn reduce64(s: [64]u8, endian: std.builtin.Endian) CompressedScalar {
 42    return Scalar.fromBytes64(s, endian).toBytes(endian);
 43}
 44
 45/// Return a*b (mod L)
 46pub fn mul(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
 47    return (try Scalar.fromBytes(a, endian)).mul(try Scalar.fromBytes(b, endian)).toBytes(endian);
 48}
 49
 50/// Return a*b+c (mod L)
 51pub fn mulAdd(a: CompressedScalar, b: CompressedScalar, c: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
 52    return (try Scalar.fromBytes(a, endian)).mul(try Scalar.fromBytes(b, endian)).add(try Scalar.fromBytes(c, endian)).toBytes(endian);
 53}
 54
 55/// Return a+b (mod L)
 56pub fn add(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
 57    return (try Scalar.fromBytes(a, endian)).add(try Scalar.fromBytes(b, endian)).toBytes(endian);
 58}
 59
 60/// Return -s (mod L)
 61pub fn neg(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
 62    return (try Scalar.fromBytes(s, endian)).neg().toBytes(endian);
 63}
 64
 65/// Return (a-b) (mod L)
 66pub fn sub(a: CompressedScalar, b: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!CompressedScalar {
 67    return (try Scalar.fromBytes(a, endian)).sub(try Scalar.fromBytes(b, endian)).toBytes(endian);
 68}
 69
 70/// Return a random scalar
 71pub fn random(endian: std.builtin.Endian) CompressedScalar {
 72    return Scalar.random().toBytes(endian);
 73}
 74
 75/// A scalar in unpacked representation.
 76pub const Scalar = struct {
 77    fe: Fe,
 78
 79    /// Zero.
 80    pub const zero = Scalar{ .fe = Fe.zero };
 81
 82    /// One.
 83    pub const one = Scalar{ .fe = Fe.one };
 84
 85    /// Unpack a serialized representation of a scalar.
 86    pub fn fromBytes(s: CompressedScalar, endian: std.builtin.Endian) NonCanonicalError!Scalar {
 87        return Scalar{ .fe = try Fe.fromBytes(s, endian) };
 88    }
 89
 90    /// Reduce a 384 bit input to the field size.
 91    pub fn fromBytes48(s: [48]u8, endian: std.builtin.Endian) Scalar {
 92        const t = ScalarDouble.fromBytes(384, s, endian);
 93        return t.reduce(384);
 94    }
 95
 96    /// Reduce a 512 bit input to the field size.
 97    pub fn fromBytes64(s: [64]u8, endian: std.builtin.Endian) Scalar {
 98        const t = ScalarDouble.fromBytes(512, s, endian);
 99        return t.reduce(512);
100    }
101
102    /// Pack a scalar into bytes.
103    pub fn toBytes(n: Scalar, endian: std.builtin.Endian) CompressedScalar {
104        return n.fe.toBytes(endian);
105    }
106
107    /// Return true if the scalar is zero..
108    pub fn isZero(n: Scalar) bool {
109        return n.fe.isZero();
110    }
111
112    /// Return true if the scalar is odd.
113    pub fn isOdd(n: Scalar) bool {
114        return n.fe.isOdd();
115    }
116
117    /// Return true if a and b are equivalent.
118    pub fn equivalent(a: Scalar, b: Scalar) bool {
119        return a.fe.equivalent(b.fe);
120    }
121
122    /// Compute x+y (mod L)
123    pub fn add(x: Scalar, y: Scalar) Scalar {
124        return Scalar{ .fe = x.fe.add(y.fe) };
125    }
126
127    /// Compute x-y (mod L)
128    pub fn sub(x: Scalar, y: Scalar) Scalar {
129        return Scalar{ .fe = x.fe.sub(y.fe) };
130    }
131
132    /// Compute 2n (mod L)
133    pub fn dbl(n: Scalar) Scalar {
134        return Scalar{ .fe = n.fe.dbl() };
135    }
136
137    /// Compute x*y (mod L)
138    pub fn mul(x: Scalar, y: Scalar) Scalar {
139        return Scalar{ .fe = x.fe.mul(y.fe) };
140    }
141
142    /// Compute x^2 (mod L)
143    pub fn sq(n: Scalar) Scalar {
144        return Scalar{ .fe = n.fe.sq() };
145    }
146
147    /// Compute x^n (mod L)
148    pub fn pow(a: Scalar, comptime T: type, comptime n: T) Scalar {
149        return Scalar{ .fe = a.fe.pow(n) };
150    }
151
152    /// Compute -x (mod L)
153    pub fn neg(n: Scalar) Scalar {
154        return Scalar{ .fe = n.fe.neg() };
155    }
156
157    /// Compute x^-1 (mod L)
158    pub fn invert(n: Scalar) Scalar {
159        return Scalar{ .fe = n.fe.invert() };
160    }
161
162    /// Return true if n is a quadratic residue mod L.
163    pub fn isSquare(n: Scalar) bool {
164        return n.fe.isSquare();
165    }
166
167    /// Return the square root of L, or NotSquare if there isn't any solutions.
168    pub fn sqrt(n: Scalar) NotSquareError!Scalar {
169        return Scalar{ .fe = try n.fe.sqrt() };
170    }
171
172    /// Return a random scalar < L.
173    pub fn random() Scalar {
174        var s: [48]u8 = undefined;
175        while (true) {
176            crypto.random.bytes(&s);
177            const n = Scalar.fromBytes48(s, .little);
178            if (!n.isZero()) {
179                return n;
180            }
181        }
182    }
183};
184
/// Wide integer (up to 512 bits) split into three limbs of at most 192 bits,
/// each stored as a field element. The represented value is
/// x1 + x2 * 2^192 + x3 * 2^384; `reduce` folds it back into a single scalar.
const ScalarDouble = struct {
    x1: Fe,
    x2: Fe,
    x3: Fe,

    /// Split the `bits`-bit integer encoded in `s_` (byte order given by `endian`)
    /// into 24-byte limbs. Limbs that the input does not reach are zero.
    fn fromBytes(comptime bits: usize, s_: [bits / 8]u8, endian: std.builtin.Endian) ScalarDouble {
        // `bits` and `Fe.saturated_bits` are compile-time known, so an
        // unsupported width is rejected at compile time rather than at runtime.
        comptime debug.assert(bits > 0 and bits <= 512 and bits >= Fe.saturated_bits and bits <= Fe.saturated_bits * 3);

        // Work on a little-endian copy of the input.
        var s = s_;
        if (endian == .big) {
            mem.reverse(u8, &s);
        }
        var t = ScalarDouble{ .x1 = undefined, .x2 = Fe.zero, .x3 = Fe.zero };
        {
            // Low limb: bytes 0..24, zero-padded to a full encoding.
            // The limb is below 2^192, hence below the field order, so decoding cannot fail.
            var b = [_]u8{0} ** encoded_length;
            const len = @min(s.len, 24);
            b[0..len].* = s[0..len].*;
            t.x1 = Fe.fromBytes(b, .little) catch unreachable;
        }
        if (s_.len >= 24) {
            // Middle limb: bytes 24..48 (or fewer if the input is shorter).
            var b = [_]u8{0} ** encoded_length;
            const len = @min(s.len - 24, 24);
            b[0..len].* = s[24..][0..len].*;
            t.x2 = Fe.fromBytes(b, .little) catch unreachable;
        }
        if (s_.len >= 48) {
            // High limb: whatever remains past byte 48 (at most 16 bytes, possibly none).
            var b = [_]u8{0} ** encoded_length;
            const len = s.len - 48;
            b[0..len].* = s[48..][0..len].*;
            t.x3 = Fe.fromBytes(b, .little) catch unreachable;
        }
        return t;
    }

    /// Fold the limbs into a single scalar: x1 + x2 * 2^192 + x3 * 2^384,
    /// computed with field arithmetic. Limbs beyond `bits` are not touched.
    fn reduce(expanded: ScalarDouble, comptime bits: usize) Scalar {
        // Same compile-time width check as in `fromBytes`.
        comptime debug.assert(bits > 0 and bits <= Fe.saturated_bits * 3 and bits <= 512);
        var fe = expanded.x1;
        if (bits >= 192) {
            // 2^192 is below the field order, so the conversion cannot fail.
            const st1 = Fe.fromInt(1 << 192) catch unreachable;
            fe = fe.add(expanded.x2.mul(st1));
            if (bits >= 384) {
                // (2^192)^2 = 2^384, the weight of the high limb.
                const st2 = st1.sq();
                fe = fe.add(expanded.x3.mul(st2));
            }
        }
        return Scalar{ .fe = fe };
    }
};