Commit 49d6dd3ecb

Andrew Kelley <andrew@ziglang.org>
2023-11-22 04:21:57
std.crypto.ff: simplify implementation
* Take advantage of multi-object for loops. * Remove use of BoundedArray since it had no meaningful impact on safety or readability. * Simplify some complex expressions, such as using `!` to invert a boolean value.
1 parent 7b3556a
Changed files (1)
lib
std
crypto
lib/std/crypto/ff.zig
@@ -12,7 +12,6 @@ const math = std.math;
 const mem = std.mem;
 const meta = std.meta;
 const testing = std.testing;
-const BoundedArray = std.BoundedArray;
 const assert = std.debug.assert;
 const Endian = std.builtin.Endian;
 
@@ -63,46 +62,54 @@ pub fn Uint(comptime max_bits: comptime_int) type {
 
     return struct {
         const Self = @This();
-
         const max_limbs_count = math.divCeil(usize, max_bits, t_bits) catch unreachable;
-        const Limbs = BoundedArray(Limb, max_limbs_count);
-        limbs: Limbs,
+
+        limbs_buffer: [max_limbs_count]Limb,
+        /// The number of active limbs.
+        limbs_len: usize,
 
         /// Number of bytes required to serialize an integer.
         pub const encoded_bytes = math.divCeil(usize, max_bits, 8) catch unreachable;
 
-        // Returns the number of active limbs.
-        fn limbs_count(self: Self) usize {
-            return self.limbs.len;
+        /// Constant slice of active limbs.
+        fn limbsConst(self: *const Self) []const Limb {
+            return self.limbs_buffer[0..self.limbs_len];
+        }
+
+        /// Mutable slice of active limbs.
+        fn limbs(self: *Self) []Limb {
+            return self.limbs_buffer[0..self.limbs_len];
         }
 
         // Removes limbs whose value is zero from the active limbs.
         fn normalize(self: Self) Self {
             var res = self;
-            if (self.limbs_count() < 2) {
+            if (self.limbs_len < 2) {
                 return res;
             }
-            var i = self.limbs_count() - 1;
-            while (i > 0 and res.limbs.get(i) == 0) : (i -= 1) {}
-            res.limbs.resize(i + 1) catch unreachable;
+            var i = self.limbs_len - 1;
+            while (i > 0 and res.limbsConst()[i] == 0) : (i -= 1) {}
+            res.limbs_len = i + 1;
+            assert(res.limbs_len <= res.limbs_buffer.len);
             return res;
         }
 
         /// The zero integer.
-        pub const zero = zero: {
-            var limbs = Limbs.init(0) catch unreachable;
-            limbs.appendNTimesAssumeCapacity(0, max_limbs_count);
-            break :zero Self{ .limbs = limbs };
+        pub const zero: Self = .{
+            .limbs_buffer = [1]Limb{0} ** max_limbs_count,
+            .limbs_len = max_limbs_count,
         };
 
         /// Creates a new big integer from a primitive type.
         /// This function may not run in constant time.
-        pub fn fromPrimitive(comptime T: type, x_: T) OverflowError!Self {
-            var x = x_;
-            var out = Self.zero;
-            for (0..out.limbs.capacity()) |i| {
-                const t = if (@bitSizeOf(T) > t_bits) @as(TLimb, @truncate(x)) else x;
-                out.limbs.set(i, t);
+        pub fn fromPrimitive(comptime T: type, init_value: T) OverflowError!Self {
+            var x = init_value;
+            var out: Self = .{
+                .limbs_buffer = undefined,
+                .limbs_len = max_limbs_count,
+            };
+            for (&out.limbs_buffer) |*limb| {
+                limb.* = if (@bitSizeOf(T) > t_bits) @as(TLimb, @truncate(x)) else x;
                 x = math.shr(T, x, t_bits);
             }
             if (x != 0) {
@@ -115,13 +122,13 @@ pub fn Uint(comptime max_bits: comptime_int) type {
         /// This function may not run in constant time.
         pub fn toPrimitive(self: Self, comptime T: type) OverflowError!T {
             var x: T = 0;
-            var i = self.limbs_count() - 1;
+            var i = self.limbs_len - 1;
             while (true) : (i -= 1) {
                 if (@bitSizeOf(T) >= t_bits and math.shr(T, x, @bitSizeOf(T) - t_bits) != 0) {
                     return error.Overflow;
                 }
                 x = math.shl(T, x, t_bits);
-                const v = math.cast(T, self.limbs.get(i)) orelse return error.Overflow;
+                const v = math.cast(T, self.limbsConst()[i]) orelse return error.Overflow;
                 x |= v;
                 if (i == 0) break;
             }
@@ -140,9 +147,9 @@ pub fn Uint(comptime max_bits: comptime_int) type {
                 .big => bytes.len - 1,
                 .little => 0,
             };
-            for (0..self.limbs.len) |i| {
+            for (0..self.limbs_len) |i| {
                 var remaining_bits = t_bits;
-                var limb = self.limbs.get(i);
+                var limb = self.limbsConst()[i];
                 while (remaining_bits >= 8) {
                     bytes[out_i] |= math.shl(u8, @as(u8, @truncate(limb)), shift);
                     const consumed = 8 - shift;
@@ -152,7 +159,7 @@ pub fn Uint(comptime max_bits: comptime_int) type {
                     switch (endian) {
                         .big => {
                             if (out_i == 0) {
-                                if (i != self.limbs.len - 1 or limb != 0) {
+                                if (i != self.limbs_len - 1 or limb != 0) {
                                     return error.Overflow;
                                 }
                                 return;
@@ -162,7 +169,7 @@ pub fn Uint(comptime max_bits: comptime_int) type {
                         .little => {
                             out_i += 1;
                             if (out_i == bytes.len) {
-                                if (i != self.limbs.len - 1 or limb != 0) {
+                                if (i != self.limbs_len - 1 or limb != 0) {
                                     return error.Overflow;
                                 }
                                 return;
@@ -187,20 +194,20 @@ pub fn Uint(comptime max_bits: comptime_int) type {
             };
             while (true) {
                 const bi = bytes[i];
-                out.limbs.set(out_i, out.limbs.get(out_i) | math.shl(Limb, bi, shift));
+                out.limbs()[out_i] |= math.shl(Limb, bi, shift);
                 shift += 8;
                 if (shift >= t_bits) {
                     shift -= t_bits;
-                    out.limbs.set(out_i, @as(TLimb, @truncate(out.limbs.get(out_i))));
+                    out.limbs()[out_i] = @as(TLimb, @truncate(out.limbs()[out_i]));
                     const overflow = math.shr(Limb, bi, 8 - shift);
                     out_i += 1;
-                    if (out_i >= out.limbs.len) {
+                    if (out_i >= out.limbs_len) {
                         if (overflow != 0 or i != 0) {
                             return error.Overflow;
                         }
                         break;
                     }
-                    out.limbs.set(out_i, overflow);
+                    out.limbs()[out_i] = overflow;
                 }
                 switch (endian) {
                     .big => {
@@ -218,32 +225,31 @@ pub fn Uint(comptime max_bits: comptime_int) type {
 
         /// Returns `true` if both integers are equal.
         pub fn eql(x: Self, y: Self) bool {
-            return crypto.utils.timingSafeEql([max_limbs_count]Limb, x.limbs.buffer, y.limbs.buffer);
+            return crypto.utils.timingSafeEql([max_limbs_count]Limb, x.limbs_buffer, y.limbs_buffer);
         }
 
         /// Compares two integers.
         pub fn compare(x: Self, y: Self) math.Order {
             return crypto.utils.timingSafeCompare(
                 Limb,
-                x.limbs.constSlice(),
-                y.limbs.constSlice(),
+                x.limbsConst(),
+                y.limbsConst(),
                 .little,
             );
         }
 
         /// Returns `true` if the integer is zero.
         pub fn isZero(x: Self) bool {
-            const x_limbs = x.limbs.constSlice();
             var t: Limb = 0;
-            for (0..x.limbs_count()) |i| {
-                t |= x_limbs[i];
+            for (x.limbsConst()) |elem| {
+                t |= elem;
             }
             return ct.eql(t, 0);
         }
 
         /// Returns `true` if the integer is odd.
         pub fn isOdd(x: Self) bool {
-            return @as(bool, @bitCast(@as(u1, @truncate(x.limbs.get(0)))));
+            return @as(u1, @truncate(x.limbsConst()[0])) != 0;
         }
 
         /// Adds `y` to `x`, and returns `true` if the operation overflowed.
@@ -258,39 +264,31 @@ pub fn Uint(comptime max_bits: comptime_int) type {
 
         // Replaces the limbs of `x` with the limbs of `y` if `on` is `true`.
         fn cmov(x: *Self, on: bool, y: Self) void {
-            const x_limbs = x.limbs.slice();
-            const y_limbs = y.limbs.constSlice();
-            for (0..y.limbs_count()) |i| {
-                x_limbs[i] = ct.select(on, y_limbs[i], x_limbs[i]);
+            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
+                x_limb.* = ct.select(on, y_limb, x_limb.*);
             }
         }
 
-        // Adds `y` to `x` if `on` is `true`, and returns `true` if the operation overflowed.
+        // Adds `y` to `x` if `on` is `true`, and returns `true` if the
+        // operation overflowed.
         fn conditionalAddWithOverflow(x: *Self, on: bool, y: Self) u1 {
-            assert(x.limbs_count() == y.limbs_count()); // Operands must have the same size.
-            const x_limbs = x.limbs.slice();
-            const y_limbs = y.limbs.constSlice();
-
             var carry: u1 = 0;
-            for (0..x.limbs_count()) |i| {
-                const res = x_limbs[i] + y_limbs[i] + carry;
-                x_limbs[i] = ct.select(on, @as(TLimb, @truncate(res)), x_limbs[i]);
-                carry = @as(u1, @truncate(res >> t_bits));
+            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
+                const res = x_limb.* + y_limb + carry;
+                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
+                carry = @truncate(res >> t_bits);
             }
             return carry;
         }
 
-        // Subtracts `y` from `x` if `on` is `true`, and returns `true` if the operation overflowed.
+        // Subtracts `y` from `x` if `on` is `true`, and returns `true` if the
+        // operation overflowed.
         fn conditionalSubWithOverflow(x: *Self, on: bool, y: Self) u1 {
-            assert(x.limbs_count() == y.limbs_count()); // Operands must have the same size.
-            const x_limbs = x.limbs.slice();
-            const y_limbs = y.limbs.constSlice();
-
             var borrow: u1 = 0;
-            for (0..x.limbs_count()) |i| {
-                const res = x_limbs[i] -% y_limbs[i] -% borrow;
-                x_limbs[i] = ct.select(on, @as(TLimb, @truncate(res)), x_limbs[i]);
-                borrow = @as(u1, @truncate(res >> t_bits));
+            for (x.limbs(), y.limbsConst()) |*x_limb, y_limb| {
+                const res = x_limb.* -% y_limb -% borrow;
+                x_limb.* = ct.select(on, @as(TLimb, @truncate(res)), x_limb.*);
+                borrow = @truncate(res >> t_bits);
             }
             return borrow;
         }
@@ -315,7 +313,7 @@ fn Fe_(comptime bits: comptime_int) type {
 
         // The number of active limbs to represent the field element.
         fn limbs_count(self: Self) usize {
-            return self.v.limbs_count();
+            return self.v.limbs_len;
         }
 
         /// Creates a field element from a primitive.
@@ -398,7 +396,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
 
         // Number of active limbs in the modulus.
         fn limbs_count(self: Self) usize {
-            return self.v.limbs_count();
+            return self.v.limbs_len;
         }
 
         /// Actual size of the modulus, in bits.
@@ -409,7 +407,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
         /// Returns the element `1`.
         pub fn one(self: Self) Fe {
             var fe = self.zero;
-            fe.v.limbs.set(0, 1);
+            fe.v.limbs()[0] = 1;
             return fe;
         }
 
@@ -419,10 +417,10 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
             if (!v_.isOdd()) return error.EvenModulus;
 
             var v = v_.normalize();
-            const hi = v.limbs.get(v.limbs_count() - 1);
-            const lo = v.limbs.get(0);
+            const hi = v.limbsConst()[v.limbs_len - 1];
+            const lo = v.limbsConst()[0];
 
-            if (v.limbs_count() < 2 and lo < 3) {
+            if (v.limbs_len < 2 and lo < 3) {
                 return error.ModulusTooSmall;
             }
 
@@ -481,18 +479,19 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
             const new_len = self.limbs_count();
             if (fe.limbs_count() < new_len) return error.Overflow;
             var acc: Limb = 0;
-            for (fe.v.limbs.constSlice()[new_len..]) |limb| {
+            for (fe.v.limbsConst()[new_len..]) |limb| {
                 acc |= limb;
             }
             if (acc != 0) return error.Overflow;
-            try fe.v.limbs.resize(new_len);
+            if (new_len > fe.v.limbs_buffer.len) return error.Overflow;
+            fe.v.limbs_len = new_len;
         }
 
         // Computes R^2 for the Montgomery representation.
         fn computeRR(self: *Self) void {
             self.rr = self.zero;
             const n = self.rr.limbs_count();
-            self.rr.v.limbs.set(n - 1, 1);
+            self.rr.v.limbs()[n - 1] = 1;
             for ((n - 1)..(2 * n)) |_| {
                 self.shiftIn(&self.rr, 0);
             }
@@ -502,9 +501,9 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
         /// Computes x << t_bits + y (mod m)
         fn shiftIn(self: Self, x: *Fe, y: Limb) void {
             var d = self.zero;
-            const x_limbs = x.v.limbs.slice();
-            const d_limbs = d.v.limbs.slice();
-            const m_limbs = self.v.limbs.constSlice();
+            const x_limbs = x.v.limbs();
+            const d_limbs = d.v.limbs();
+            const m_limbs = self.v.limbsConst();
 
             var need_sub = false;
             var i: usize = t_bits - 1;
@@ -569,18 +568,18 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
         /// Reduces an arbitrary `Uint`, converting it to a field element.
         pub fn reduce(self: Self, x: anytype) Fe {
             var out = self.zero;
-            var i = x.limbs_count() - 1;
+            var i = x.limbs_len - 1;
             if (self.limbs_count() >= 2) {
                 const start = @min(i, self.limbs_count() - 2);
                 var j = start;
                 while (true) : (j -= 1) {
-                    out.v.limbs.set(j, x.limbs.get(i));
+                    out.v.limbs()[j] = x.limbsConst()[i];
                     i -= 1;
                     if (j == 0) break;
                 }
             }
             while (true) : (i -= 1) {
-                self.shiftIn(&out, x.limbs.get(i));
+                self.shiftIn(&out, x.limbsConst()[i]);
                 if (i == 0) break;
             }
             return out;
@@ -591,10 +590,10 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
             assert(d.limbs_count() == y.limbs_count());
             assert(d.limbs_count() == self.limbs_count());
 
-            const a_limbs = x.v.limbs.constSlice();
-            const b_limbs = y.v.limbs.constSlice();
-            const d_limbs = d.v.limbs.slice();
-            const m_limbs = self.v.limbs.constSlice();
+            const a_limbs = x.v.limbsConst();
+            const b_limbs = y.v.limbsConst();
+            const d_limbs = d.v.limbs();
+            const m_limbs = self.v.limbsConst();
 
             var overflow: u1 = 0;
             for (0..self.limbs_count()) |i| {
@@ -685,7 +684,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
                         const k: u1 = @truncate(b >> j);
                         if (k != 0) {
                             const t = self.montgomeryMul(out, x_m);
-                            @memcpy(out.v.limbs.slice(), t.v.limbs.constSlice());
+                            @memcpy(out.v.limbs(), t.v.limbsConst());
                         }
                         if (j == 0) break;
                     }
@@ -731,7 +730,7 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
                         }
                         const t1 = self.montgomeryMul(out, t0);
                         if (public) {
-                            @memcpy(out.v.limbs.slice(), t1.v.limbs.constSlice());
+                            @memcpy(out.v.limbs(), t1.v.limbsConst());
                         } else {
                             out.v.cmov(!ct.eql(k, 0), t1.v);
                         }
@@ -790,9 +789,9 @@ pub fn Modulus(comptime max_bits: comptime_int) type {
         pub fn powPublic(self: Self, x: Fe, e: Fe) NullExponentError!Fe {
             var e_normalized = Fe{ .v = e.v.normalize() };
             var buf_: [Fe.encoded_bytes]u8 = undefined;
-            var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_count() * t_bits, 8) catch unreachable];
+            var buf = buf_[0 .. math.divCeil(usize, e_normalized.v.limbs_len * t_bits, 8) catch unreachable];
             e_normalized.toBytes(buf, .little) catch unreachable;
-            const leading = @clz(e_normalized.v.limbs.get(e_normalized.v.limbs_count() - carry_bits));
+            const leading = @clz(e_normalized.v.limbsConst()[e_normalized.v.limbs_len - carry_bits]);
             buf = buf[0 .. buf.len - leading / 8];
             return self.powWithEncodedPublicExponent(x, buf, .little);
         }
@@ -835,20 +834,16 @@ const ct_protected = struct {
 
     // Compares two big integers in constant time, returning true if x < y.
     fn limbsCmpLt(x: anytype, y: @TypeOf(x)) bool {
-        assert(x.limbs_count() == y.limbs_count());
-        const x_limbs = x.limbs.constSlice();
-        const y_limbs = y.limbs.constSlice();
-
         var c: u1 = 0;
-        for (0..x.limbs_count()) |i| {
-            c = @as(u1, @truncate((x_limbs[i] -% y_limbs[i] -% c) >> t_bits));
+        for (x.limbsConst(), y.limbsConst()) |x_limb, y_limb| {
+            c = @truncate((x_limb -% y_limb -% c) >> t_bits);
         }
-        return @as(bool, @bitCast(c));
+        return c != 0;
     }
 
     // Compares two big integers in constant time, returning true if x >= y.
     fn limbsCmpGeq(x: anytype, y: @TypeOf(x)) bool {
-        return @as(bool, @bitCast(1 - @intFromBool(ct.limbsCmpLt(x, y))));
+        return !ct.limbsCmpLt(x, y);
     }
 
     // Multiplies two limbs and returns the result as a wide limb.