Commit b08924e938

samy007 <samy2014@free.fr>
2025-04-14 20:46:06
std.math.big.int: changed llshr and llshl implementation
1 parent b635b37
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -17,7 +17,6 @@ const Endian = std.builtin.Endian;
 const Signedness = std.builtin.Signedness;
 const native_endian = builtin.cpu.arch.endian();
 
-
 /// Returns the number of limbs needed to store `scalar`, which must be a
 /// primitive integer value.
 /// Note: A comptime-known upper bound of this value that may be used
@@ -210,7 +209,7 @@ pub const Mutable = struct {
         for (self.limbs[0..self.len]) |limb| {
             std.debug.print("{x} ", .{limb});
         }
-        std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.positive });
+        std.debug.print("len={} capacity={} positive={}\n", .{ self.len, self.limbs.len, self.positive });
     }
 
     /// Clones an Mutable and returns a new Mutable with the same value. The new Mutable is a deep copy and
@@ -1104,8 +1103,8 @@ pub const Mutable = struct {
     /// Asserts there is enough memory to fit the result. The upper bound Limb count is
     /// `a.limbs.len + (shift / (@sizeOf(Limb) * 8))`.
     pub fn shiftLeft(r: *Mutable, a: Const, shift: usize) void {
-        llshl(r.limbs, a.limbs, shift);
-        r.normalize(a.limbs.len + (shift / limb_bits) + 1);
+        const new_len = llshl(r.limbs, a.limbs, shift);
+        r.normalize(new_len);
         r.positive = a.positive;
     }
 
@@ -1173,8 +1172,8 @@ pub const Mutable = struct {
 
         // This shift should not be able to overflow, so invoke llshl and normalize manually
         // to avoid the extra required limb.
-        llshl(r.limbs, a.limbs, shift);
-        r.normalize(a.limbs.len + (shift / limb_bits));
+        const new_len = llshl(r.limbs, a.limbs, shift);
+        r.normalize(new_len);
         r.positive = a.positive;
     }
 
@@ -1182,7 +1181,7 @@ pub const Mutable = struct {
     /// r and a may alias.
     ///
     /// Asserts there is enough memory to fit the result. The upper bound Limb count is
-    /// `a.limbs.len - (shift / (@sizeOf(Limb) * 8))`.
+    /// `a.limbs.len - (shift / (@bitSizeOf(Limb)))`.
     pub fn shiftRight(r: *Mutable, a: Const, shift: usize) void {
         const full_limbs_shifted_out = shift / limb_bits;
         const remaining_bits_shifted_out = shift % limb_bits;
@@ -1210,9 +1209,9 @@ pub const Mutable = struct {
             break :nonzero a.limbs[full_limbs_shifted_out] << not_covered != 0;
         };
 
-        llshr(r.limbs, a.limbs, shift);
+        const new_len = llshr(r.limbs, a.limbs, shift);
 
-        r.len = a.limbs.len - full_limbs_shifted_out;
+        r.len = new_len;
         r.positive = a.positive;
         if (nonzero_negative_shiftout) r.addScalar(r.toConst(), -1);
         r.normalize(r.len);
@@ -1971,7 +1970,7 @@ pub const Const = struct {
         for (self.limbs[0..self.limbs.len]) |limb| {
             std.debug.print("{x} ", .{limb});
         }
-        std.debug.print("positive={}\n", .{self.positive});
+        std.debug.print("len={} positive={}\n", .{ self.len, self.positive });
     }
 
     pub fn abs(self: Const) Const {
@@ -2673,7 +2672,7 @@ pub const Managed = struct {
         for (self.limbs[0..self.len()]) |limb| {
             std.debug.print("{x} ", .{limb});
         }
-        std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.isPositive() });
+        std.debug.print("len={} capacity={} positive={}\n", .{ self.len(), self.limbs.len, self.isPositive() });
     }
 
     /// Negate the sign.
@@ -3711,68 +3710,114 @@ fn lldiv0p5(quo: []Limb, rem: *Limb, a: []const Limb, b: HalfLimb) void {
     }
 }
 
-fn llshl(r: []Limb, a: []const Limb, shift: usize) void {
-    @setRuntimeSafety(debug_safety);
-    assert(a.len >= 1);
+/// Performs r = a << shift and returns the amount of limbs affected
+///
+/// if a and r overlaps, then r.ptr >= a.ptr is asserted
+/// r must have the capacity to store a << shift
+fn llshl(r: []Limb, a: []const Limb, shift: usize) usize {
+    std.debug.assert(a.len >= 1);
+    if (slicesOverlap(a, r))
+        std.debug.assert(@intFromPtr(r.ptr) >= @intFromPtr(a.ptr));
+
+    if (shift == 0) {
+        if (a.ptr != r.ptr)
+            std.mem.copyBackwards(Limb, r[0..a.len], a);
+        return a.len;
+    }
+    if (shift >= limb_bits) {
+        const limb_shift = shift / limb_bits;
+
+        const affected = llshl(r[limb_shift..], a, shift % limb_bits);
+        @memset(r[0..limb_shift], 0);
+
+        return limb_shift + affected;
+    }
 
-    const interior_limb_shift = @as(Log2Limb, @truncate(shift));
+    // shift is guaranteed to be < limb_bits
+    const bit_shift: Log2Limb = @truncate(shift);
+    const opposite_bit_shift: Log2Limb = @truncate(limb_bits - bit_shift);
 
     // We only need the extra limb if the shift of the last element overflows.
     // This is useful for the implementation of `shiftLeftSat`.
-    if (a[a.len - 1] << interior_limb_shift >> interior_limb_shift != a[a.len - 1]) {
-        assert(r.len >= a.len + (shift / limb_bits) + 1);
+    const overflows = a[a.len - 1] >> opposite_bit_shift != 0;
+    if (overflows) {
+        std.debug.assert(r.len >= a.len + 1);
     } else {
-        assert(r.len >= a.len + (shift / limb_bits));
+        std.debug.assert(r.len >= a.len);
     }
 
-    const limb_shift = shift / limb_bits + 1;
+    var i: usize = a.len;
+    if (overflows) {
+        // r is asserted to be large enough above
+        r[a.len] = a[a.len - 1] >> opposite_bit_shift;
+    }
+    while (i > 1) {
+        i -= 1;
+        r[i] = (a[i - 1] >> opposite_bit_shift) | (a[i] << bit_shift);
+    }
+    r[0] = a[0] << bit_shift;
 
-    var carry: Limb = 0;
-    var i: usize = 0;
-    while (i < a.len) : (i += 1) {
-        const src_i = a.len - i - 1;
-        const dst_i = src_i + limb_shift;
+    return a.len + @intFromBool(overflows);
+}
 
-        const src_digit = a[src_i];
-        r[dst_i] = carry | @call(.always_inline, math.shr, .{
-            Limb,
-            src_digit,
-            limb_bits - @as(Limb, @intCast(interior_limb_shift)),
-        });
-        carry = (src_digit << interior_limb_shift);
+/// Performs r = a >> shift and returns the amount of limbs affected
+///
+/// if a and r overlaps, then r.ptr <= a.ptr is asserted
+/// r must have the capacity to store a >> shift
+///
+/// See tests below for examples of behaviour
+fn llshr(r: []Limb, a: []const Limb, shift: usize) usize {
+    if (slicesOverlap(a, r))
+        std.debug.assert(@intFromPtr(r.ptr) <= @intFromPtr(a.ptr));
+
+    if (a.len == 0) return 0;
+
+    if (shift == 0) {
+        std.debug.assert(r.len >= a.len);
+
+        if (a.ptr != r.ptr)
+            std.mem.copyForwards(Limb, r[0..a.len], a);
+        return a.len;
+    }
+    if (shift >= limb_bits) {
+        if (shift / limb_bits >= a.len) {
+            r[0] = 0;
+            return 1;
+        }
+        return llshr(r, a[shift / limb_bits ..], shift % limb_bits);
     }
 
-    r[limb_shift - 1] = carry;
-    @memset(r[0 .. limb_shift - 1], 0);
-}
+    // shift is guaranteed to be < limb_bits
+    const bit_shift: Log2Limb = @truncate(shift);
+    const opposite_bit_shift: Log2Limb = @truncate(limb_bits - bit_shift);
 
-fn llshr(r: []Limb, a: []const Limb, shift: usize) void {
-    @setRuntimeSafety(debug_safety);
-    assert(a.len >= 1);
-    assert(r.len >= a.len - (shift / limb_bits));
+    // special case, where there is a risk to set r to 0
+    if (a.len == 1) {
+        r[0] = a[0] >> bit_shift;
+        return 1;
+    }
+    if (a.len == 0) {
+        r[0] = 0;
+        return 1;
+    }
 
-    const limb_shift = shift / limb_bits;
-    const interior_limb_shift = @as(Log2Limb, @truncate(shift));
+    // if the most significant limb becomes 0 after the shift
+    const shrink = a[a.len - 1] >> bit_shift == 0;
+    std.debug.assert(r.len >= a.len - @intFromBool(!shrink));
 
     var i: usize = 0;
-    while (i < a.len - limb_shift) : (i += 1) {
-        const dst_i = i;
-        const src_i = dst_i + limb_shift;
-
-        const src_digit = a[src_i];
-        const src_digit_next = if (src_i + 1 < a.len) a[src_i + 1] else 0;
-        const carry = @call(.always_inline, math.shl, .{
-            Limb,
-            src_digit_next,
-            limb_bits - @as(Limb, @intCast(interior_limb_shift)),
-        });
-        r[dst_i] = carry | (src_digit >> interior_limb_shift);
+    while (i < a.len - 1) : (i += 1) {
+        r[i] = (a[i] >> bit_shift) | (a[i + 1] << opposite_bit_shift);
     }
+
+    if (!shrink)
+        r[i] = a[i] >> bit_shift;
+
+    return a.len - @intFromBool(shrink);
 }
 
 // r = ~r
 fn llnot(r: []Limb) void {
-
     for (r) |*elem| {
         elem.* = ~elem.*;
     }
@@ -4107,7 +4152,7 @@ fn llsquareBasecase(r: []Limb, x: []const Limb) void {
     }
 
     // Each product appears twice, multiply by 2
-    llshl(r, r[0 .. 2 * x_norm.len], 1);
+    _ = llshl(r, r[0 .. 2 * x_norm.len], 1);
 
     for (x_norm, 0..) |v, i| {
         // Compute and add the squares