master
  1const std = @import("std");
  2const fmt = std.fmt;
  3
  4const EncodingError = std.crypto.errors.EncodingError;
  5const IdentityElementError = std.crypto.errors.IdentityElementError;
  6const NonCanonicalError = std.crypto.errors.NonCanonicalError;
  7const WeakPublicKeyError = std.crypto.errors.WeakPublicKeyError;
  8
  9/// Group operations over Edwards25519.
 10pub const Ristretto255 = struct {
 11    /// The underlying elliptic curve.
 12    pub const Curve = @import("edwards25519.zig").Edwards25519;
 13    /// The underlying prime field.
 14    pub const Fe = Curve.Fe;
 15    /// Field arithmetic mod the order of the main subgroup.
 16    pub const scalar = Curve.scalar;
 17    /// Length in byte of an encoded element.
 18    pub const encoded_length: usize = 32;
 19
 20    p: Curve,
 21
 22    fn sqrtRatioM1(u: Fe, v: Fe) struct { ratio_is_square: u32, root: Fe } {
 23        const v3 = v.sq().mul(v); // v^3
 24        var x = v3.sq().mul(u).mul(v).pow2523().mul(v3).mul(u); // uv^3(uv^7)^((q-5)/8)
 25        const vxx = x.sq().mul(v); // vx^2
 26        const m_root_check = vxx.sub(u); // vx^2-u
 27        const p_root_check = vxx.add(u); // vx^2+u
 28        const f_root_check = u.mul(Fe.sqrtm1).add(vxx); // vx^2+u*sqrt(-1)
 29        const has_m_root = m_root_check.isZero();
 30        const has_p_root = p_root_check.isZero();
 31        const has_f_root = f_root_check.isZero();
 32        const x_sqrtm1 = x.mul(Fe.sqrtm1); // x*sqrt(-1)
 33        x.cMov(x_sqrtm1, @intFromBool(has_p_root) | @intFromBool(has_f_root));
 34        return .{ .ratio_is_square = @intFromBool(has_m_root) | @intFromBool(has_p_root), .root = x.abs() };
 35    }
 36
 37    fn rejectNonCanonical(s: [encoded_length]u8) NonCanonicalError!void {
 38        if ((s[0] & 1) != 0) {
 39            return error.NonCanonical;
 40        }
 41        try Fe.rejectNonCanonical(s, false);
 42    }
 43
 44    /// Reject the neutral element.
 45    pub fn rejectIdentity(p: Ristretto255) IdentityElementError!void {
 46        return p.p.rejectIdentity();
 47    }
 48
 49    /// The base point (Ristretto is a curve in desguise).
 50    pub const basePoint = Ristretto255{ .p = Curve.basePoint };
 51
 52    /// Decode a Ristretto255 representative.
 53    pub fn fromBytes(s: [encoded_length]u8) (NonCanonicalError || EncodingError)!Ristretto255 {
 54        try rejectNonCanonical(s);
 55        const s_ = Fe.fromBytes(s);
 56        const ss = s_.sq(); // s^2
 57        const u1_ = Fe.one.sub(ss); // (1-s^2)
 58        const u1u1 = u1_.sq(); // (1-s^2)^2
 59        const u2_ = Fe.one.add(ss); // (1+s^2)
 60        const u2u2 = u2_.sq(); // (1+s^2)^2
 61        const v = Fe.edwards25519d.mul(u1u1).neg().sub(u2u2); // -(d*u1^2)-u2^2
 62        const v_u2u2 = v.mul(u2u2); // v*u2^2
 63
 64        const inv_sqrt = sqrtRatioM1(Fe.one, v_u2u2);
 65        var x = inv_sqrt.root.mul(u2_);
 66        const y = inv_sqrt.root.mul(x).mul(v).mul(u1_);
 67        x = x.mul(s_);
 68        x = x.add(x).abs();
 69        const t = x.mul(y);
 70        if ((1 - inv_sqrt.ratio_is_square) | @intFromBool(t.isNegative()) | @intFromBool(y.isZero()) != 0) {
 71            return error.InvalidEncoding;
 72        }
 73        const p: Curve = .{
 74            .x = x,
 75            .y = y,
 76            .z = Fe.one,
 77            .t = t,
 78        };
 79        return Ristretto255{ .p = p };
 80    }
 81
 82    /// Encode to a Ristretto255 representative.
 83    pub fn toBytes(e: Ristretto255) [encoded_length]u8 {
 84        const p = &e.p;
 85        var u1_ = p.z.add(p.y); // Z+Y
 86        const zmy = p.z.sub(p.y); // Z-Y
 87        u1_ = u1_.mul(zmy); // (Z+Y)*(Z-Y)
 88        const u2_ = p.x.mul(p.y); // X*Y
 89        const u1_u2u2 = u2_.sq().mul(u1_); // u1*u2^2
 90        const inv_sqrt = sqrtRatioM1(Fe.one, u1_u2u2);
 91        const den1 = inv_sqrt.root.mul(u1_);
 92        const den2 = inv_sqrt.root.mul(u2_);
 93        const z_inv = den1.mul(den2).mul(p.t); // den1*den2*T
 94        const ix = p.x.mul(Fe.sqrtm1); // X*sqrt(-1)
 95        const iy = p.y.mul(Fe.sqrtm1); // Y*sqrt(-1)
 96        const eden = den1.mul(Fe.edwards25519sqrtamd); // den1/sqrt(a-d)
 97        const t_z_inv = p.t.mul(z_inv); // T*z_inv
 98
 99        const rotate = @intFromBool(t_z_inv.isNegative());
100        var x = p.x;
101        var y = p.y;
102        var den_inv = den2;
103        x.cMov(iy, rotate);
104        y.cMov(ix, rotate);
105        den_inv.cMov(eden, rotate);
106
107        const x_z_inv = x.mul(z_inv);
108        const yneg = y.neg();
109        y.cMov(yneg, @intFromBool(x_z_inv.isNegative()));
110
111        return p.z.sub(y).mul(den_inv).abs().toBytes();
112    }
113
114    fn elligator(t: Fe) Curve {
115        const r = t.sq().mul(Fe.sqrtm1); // sqrt(-1)*t^2
116        const u = r.add(Fe.one).mul(Fe.edwards25519eonemsqd); // (r+1)*(1-d^2)
117        var c = comptime Fe.one.neg(); // -1
118        const v = c.sub(r.mul(Fe.edwards25519d)).mul(r.add(Fe.edwards25519d)); // (c-r*d)*(r+d)
119        const ratio_sqrt = sqrtRatioM1(u, v);
120        const wasnt_square = 1 - ratio_sqrt.ratio_is_square;
121        var s = ratio_sqrt.root;
122        const s_prime = s.mul(t).abs().neg(); // -|s*t|
123        s.cMov(s_prime, wasnt_square);
124        c.cMov(r, wasnt_square);
125
126        const n = r.sub(Fe.one).mul(c).mul(Fe.edwards25519sqdmone).sub(v); // c*(r-1)*(d-1)^2-v
127        const w0 = s.add(s).mul(v); // 2s*v
128        const w1 = n.mul(Fe.edwards25519sqrtadm1); // n*sqrt(ad-1)
129        const ss = s.sq(); // s^2
130        const w2 = Fe.one.sub(ss); // 1-s^2
131        const w3 = Fe.one.add(ss); // 1+s^2
132
133        return .{ .x = w0.mul(w3), .y = w2.mul(w1), .z = w1.mul(w3), .t = w0.mul(w2) };
134    }
135
136    /// Map a 64-bit string into a Ristretto255 group element
137    pub fn fromUniform(h: [64]u8) Ristretto255 {
138        const p0 = elligator(Fe.fromBytes(h[0..32].*));
139        const p1 = elligator(Fe.fromBytes(h[32..64].*));
140        return Ristretto255{ .p = p0.add(p1) };
141    }
142
143    /// Double a Ristretto255 element.
144    pub fn dbl(p: Ristretto255) Ristretto255 {
145        return .{ .p = p.p.dbl() };
146    }
147
148    /// Add two Ristretto255 elements.
149    pub fn add(p: Ristretto255, q: Ristretto255) Ristretto255 {
150        return .{ .p = p.p.add(q.p) };
151    }
152
153    /// Subtract two Ristretto255 elements.
154    pub fn sub(p: Ristretto255, q: Ristretto255) Ristretto255 {
155        return .{ .p = p.p.sub(q.p) };
156    }
157
158    /// Multiply a Ristretto255 element with a scalar.
159    /// Return error.WeakPublicKey if the resulting element is
160    /// the identity element.
161    pub fn mul(p: Ristretto255, s: [encoded_length]u8) (IdentityElementError || WeakPublicKeyError)!Ristretto255 {
162        return .{ .p = try p.p.mul(s) };
163    }
164
165    /// Return true if two Ristretto255 elements are equivalent
166    pub fn equivalent(p: Ristretto255, q: Ristretto255) bool {
167        const p_ = &p.p;
168        const q_ = &q.p;
169        const a = p_.x.mul(q_.y).equivalent(p_.y.mul(q_.x));
170        const b = p_.y.mul(q_.y).equivalent(p_.x.mul(q_.x));
171        return (@intFromBool(a) | @intFromBool(b)) != 0;
172    }
173};
174
// Known-answer tests: base point encoding, decode + arithmetic, scalar
// multiplication, equivalence, and the 64-byte uniform mapping.
test "ristretto255" {
    const p = Ristretto255.basePoint;
    var buf: [256]u8 = undefined;

    // `expectEqualStrings` takes (expected, actual): the known-answer literal
    // goes first so that failure output labels both strings correctly.
    try std.testing.expectEqualStrings(
        "E2F2AE0A6ABC4E71A884A961C500515F58E30B6AA582DD8DB6A65945E08D2D76",
        try fmt.bufPrint(&buf, "{X}", .{&p.toBytes()}),
    );

    // Decode a representative, then compute 2*q + p and re-encode.
    var r: [Ristretto255.encoded_length]u8 = undefined;
    _ = try fmt.hexToBytes(r[0..], "6a493210f7499cd17fecb510ae0cea23a110e8d5b901f8acadd3095c73a3b919");
    var q = try Ristretto255.fromBytes(r);
    q = q.dbl().add(p);
    try std.testing.expectEqualStrings(
        "E882B131016B52C1D3337080187CF768423EFCCBB517BB495AB812C4160FF44E",
        try fmt.bufPrint(&buf, "{X}", .{&q.toBytes()}),
    );

    // Scalar multiplication by 15 (first byte 15, remaining 31 bytes zero).
    const s = [_]u8{15} ++ [_]u8{0} ** 31;
    const w = try p.mul(s);
    try std.testing.expectEqualStrings(
        "E0C418F7C8D9C4CDD7395B93EA124F3AD99021BB681DFC3302A9D99A2E53E64E",
        try fmt.bufPrint(&buf, "{X}", .{&w.toBytes()}),
    );

    // 16*p (four doublings) must be equivalent to 15*p + p.
    try std.testing.expect(p.dbl().dbl().dbl().dbl().equivalent(w.add(p)));

    // Mapping of a fixed 64-byte input.
    const h = [_]u8{69} ** 32 ++ [_]u8{42} ** 32;
    const ph = Ristretto255.fromUniform(h);
    try std.testing.expectEqualStrings(
        "DCCA54E037A4311EFBEEF413ACD21D35276518970B7A61DC88F8587B493D5E19",
        try fmt.bufPrint(&buf, "{X}", .{&ph.toBytes()}),
    );
}