Commit a5c79c7998

Frank Denis <124872+jedisct1@users.noreply.github.com>
2023-10-21 11:09:43
crypto.ff: faster exponentiation with short/public exponents (#17617)
RSA exponents are typically 3 or 65537, and public. For those, we don't need to use conditional moves on the exponent, and precomputing a lookup table is not worth it. So, save a few cpu cycles and some memory for that common case. For safety, make `powWithEncodedExponent()` constant-time by default, and introduce a `powWithEncodedPublicExponent()` function for exponents that are assumed to be public. With `powWithEncodedPublicExponent()`, short (<= 36 bits) exponents will take the fast path.
1 parent 54a4f24
Changed files (1)
lib
std
crypto
lib/std/crypto/ff.zig
@@ -656,6 +656,101 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
             return d;
         }
 
+        // Returns x^e (mod m), with the exponent provided as a byte string.
+        // `public` must be set to `false` if the exponent it secret.
+        fn powWithEncodedExponentInternal(self: Self, x: Fe, e: []const u8, endian: builtin.Endian, comptime public: bool) NullExponentError!Fe {
+            var acc: u8 = 0;
+            for (e) |b| acc |= b;
+            if (acc == 0) return error.NullExponent;
+
+            var out = self.one();
+            self.toMontgomery(&out) catch unreachable;
+
+            if (public and e.len < 3 or (e.len == 3 and e[if (endian == .Big) 0 else 2] <= 0b1111)) {
+                // Do not use a precomputation table for short, public exponents
+                var x_m = x;
+                if (x.montgomery == false) {
+                    self.toMontgomery(&x_m) catch unreachable;
+                }
+                var s = switch (endian) {
+                    .Big => 0,
+                    .Little => e.len - 1,
+                };
+                while (true) {
+                    const b = e[s];
+                    var j: u3 = 7;
+                    while (true) : (j -= 1) {
+                        out = self.montgomerySq(out);
+                        const k: u1 = @truncate(b >> j);
+                        if (k != 0) {
+                            const t = self.montgomeryMul(out, x_m);
+                            @memcpy(out.v.limbs.slice(), t.v.limbs.constSlice());
+                        }
+                        if (j == 0) break;
+                    }
+                    switch (endian) {
+                        .Big => {
+                            s += 1;
+                            if (s == e.len) break;
+                        },
+                        .Little => {
+                            if (s == 0) break;
+                            s -= 1;
+                        },
+                    }
+                }
+            } else {
+                // Use a precomputation table for large exponents
+                var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
+                if (x.montgomery == false) {
+                    self.toMontgomery(&pc[0]) catch unreachable;
+                }
+                for (1..pc.len) |i| {
+                    pc[i] = self.montgomeryMul(pc[i - 1], pc[0]);
+                }
+                var t0 = self.zero;
+                var s = switch (endian) {
+                    .Big => 0,
+                    .Little => e.len - 1,
+                };
+                while (true) {
+                    const b = e[s];
+                    for ([_]u3{ 4, 0 }) |j| {
+                        for (0..4) |_| {
+                            out = self.montgomerySq(out);
+                        }
+                        const k = (b >> j) & 0b1111;
+                        if (public or std.options.side_channels_mitigations == .none) {
+                            if (k == 0) continue;
+                            t0 = pc[k - 1];
+                        } else {
+                            for (pc, 0..) |t, i| {
+                                t0.v.cmov(ct.eql(k, @as(u8, @truncate(i + 1))), t.v);
+                            }
+                        }
+                        const t1 = self.montgomeryMul(out, t0);
+                        if (public) {
+                            @memcpy(out.v.limbs.slice(), t1.v.limbs.constSlice());
+                        } else {
+                            out.v.cmov(!ct.eql(k, 0), t1.v);
+                        }
+                    }
+                    switch (endian) {
+                        .Big => {
+                            s += 1;
+                            if (s == e.len) break;
+                        },
+                        .Little => {
+                            if (s == 0) break;
+                            s -= 1;
+                        },
+                    }
+                }
+            }
+            self.fromMontgomery(&out) catch unreachable;
+            return out;
+        }
+
         /// Multiplies two field elements.
         pub fn mul(self: Self, x: Fe, y: Fe) Fe {
             if (x.montgomery != y.montgomery) {
@@ -698,62 +793,25 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
             e_normalized.toBytes(buf, .Little) catch unreachable;
             const leading = @clz(e_normalized.v.limbs.get(e_normalized.v.limbs_count() - carry_bits));
             buf = buf[0 .. buf.len - leading / 8];
-            return self.powWithEncodedExponent(x, buf, .Little);
+            return self.powWithEncodedPublicExponent(x, buf, .Little);
         }
 
-        /// Returns x^e (mod m), assuming that the exponent is public, and provided as a byte string.
+        /// Returns x^e (mod m), with the exponent provided as a byte string.
         /// Exponents are usually small, so this function is faster than `powPublic` as a field element
         /// doesn't have to be created if a serialized representation is already available.
+        ///
+        /// If the exponent is public, `powWithEncodedPublicExponent()` can be used instead for a slight speedup.
         pub fn powWithEncodedExponent(self: Self, x: Fe, e: []const u8, endian: builtin.Endian) NullExponentError!Fe {
-            var acc: u8 = 0;
-            for (e) |b| acc |= b;
-            if (acc == 0) return error.NullExponent;
+            return self.powWithEncodedExponentInternal(x, e, endian, false);
+        }
 
-            var pc = [1]Fe{x} ++ [_]Fe{self.zero} ** 14;
-            if (x.montgomery == false) {
-                self.toMontgomery(&pc[0]) catch unreachable;
-            }
-            for (1..pc.len) |i| {
-                pc[i] = self.montgomeryMul(pc[i - 1], pc[0]);
-            }
-            var out = self.one();
-            self.toMontgomery(&out) catch unreachable;
-            var t0 = self.zero;
-            var s = switch (endian) {
-                .Big => 0,
-                .Little => e.len - 1,
-            };
-            while (true) {
-                const b = e[s];
-                for ([_]u3{ 4, 0 }) |j| {
-                    for (0..4) |_| {
-                        out = self.montgomerySq(out);
-                    }
-                    const k = (b >> j) & 0b1111;
-                    if (std.options.side_channels_mitigations == .none) {
-                        if (k == 0) continue;
-                        t0 = pc[k - 1];
-                    } else {
-                        for (pc, 0..) |t, i| {
-                            t0.v.cmov(ct.eql(k, @as(u8, @truncate(i + 1))), t.v);
-                        }
-                    }
-                    const t1 = self.montgomeryMul(out, t0);
-                    out.v.cmov(!ct.eql(k, 0), t1.v);
-                }
-                switch (endian) {
-                    .Big => {
-                        s += 1;
-                        if (s == e.len) break;
-                    },
-                    .Little => {
-                        if (s == 0) break;
-                        s -= 1;
-                    },
-                }
-            }
-            self.fromMontgomery(&out) catch unreachable;
-            return out;
+        /// Returns x^e (mod m), the exponent being public and provided as a byte string.
+        /// Exponents are usually small, so this function is faster than `powPublic` as a field element
+        /// doesn't have to be created if a serialized representation is already available.
+        ///
+        /// If the exponent is secret, `powWithEncodedExponent` must be used instead.
+        pub fn powWithEncodedPublicExponent(self: Self, x: Fe, e: []const u8, endian: builtin.Endian) NullExponentError!Fe {
+            return self.powWithEncodedExponentInternal(x, e, endian, true);
         }
     };
 }