Commit a31b70c4b8

LemonBoy <thatlemon@gmail.com>
2020-10-10 00:46:53
std: Add/Fix/Change parts of big.int
* Add an optimized squaring routine under the `sqr` name. Algorithms for squaring bigger numbers efficiently will come in a PR later. * Fix a bug where a multiplication was done twice if the threshold for the use of Karatsuba algorithm was crossed. Add a test to make sure this won't happen again. * Streamline `pow` method, take a `Const` parameter. * Minor tweaks to `pow`, avoid bit-reversing the exponent.
1 parent fbc6a00
Changed files (2)
lib
std
lib/std/math/big/int.zig
@@ -446,6 +446,26 @@ pub const Mutable = struct {
         rma.positive = (a.positive == b.positive);
     }
 
+    /// rma = a * a
+    ///
+    /// `rma` may not alias with `a`.
+    ///
+    /// Asserts the result fits in `rma`. An upper bound on the number of limbs needed by
+    /// rma is given by `2 * a.limbs.len + 1`.
+    ///
+    /// If `allocator` is provided, it will be used for temporary storage to improve
+    /// multiplication performance. `error.OutOfMemory` is handled with a fallback algorithm.
+    pub fn sqrNoAlias(rma: *Mutable, a: Const, opt_allocator: ?*Allocator) void {
+        assert(rma.limbs.ptr != a.limbs.ptr); // illegal aliasing
+
+        mem.set(Limb, rma.limbs, 0);
+
+        llsquare_basecase(rma.limbs, a.limbs);
+
+        rma.normalize(2 * a.limbs.len + 1);
+        rma.positive = true;
+    }
+
     /// q = a / b (rem r)
     ///
     /// a / b are floored (rounded towards 0).
@@ -1827,7 +1847,28 @@ pub const Managed = struct {
         rma.setMetadata(m.positive, m.len);
     }
 
-    pub fn pow(rma: *Managed, a: Managed, b: u32) !void {
+    /// r = a * a
+    pub fn sqr(rma: *Managed, a: Const) !void {
+        const needed_limbs = 2 * a.limbs.len + 1;
+
+        if (rma.limbs.ptr == a.limbs.ptr) {
+            var m = try Managed.initCapacity(rma.allocator, needed_limbs);
+            errdefer m.deinit();
+            var m_mut = m.toMutable();
+            m_mut.sqrNoAlias(a, rma.allocator);
+            m.setMetadata(m_mut.positive, m_mut.len);
+
+            rma.deinit();
+            rma.swap(&m);
+        } else {
+            try rma.ensureCapacity(needed_limbs);
+            var rma_mut = rma.toMutable();
+            rma_mut.sqrNoAlias(a, rma.allocator);
+            rma.setMetadata(rma_mut.positive, rma_mut.len);
+        }
+    }
+
+    pub fn pow(rma: *Managed, a: Const, b: u32) !void {
         const needed_limbs = calcPowLimbsBufferLen(a.bitCountAbs(), b);
 
         const limbs_buffer = try rma.allocator.alloc(Limb, needed_limbs);
@@ -1837,7 +1878,7 @@ pub const Managed = struct {
             var m = try Managed.initCapacity(rma.allocator, needed_limbs);
             errdefer m.deinit();
             var m_mut = m.toMutable();
-            try m_mut.pow(a.toConst(), b, limbs_buffer);
+            try m_mut.pow(a, b, limbs_buffer);
             m.setMetadata(m_mut.positive, m_mut.len);
 
             rma.deinit();
@@ -1845,7 +1886,7 @@ pub const Managed = struct {
         } else {
             try rma.ensureCapacity(needed_limbs);
             var rma_mut = rma.toMutable();
-            try rma_mut.pow(a.toConst(), b, limbs_buffer);
+            try rma_mut.pow(a, b, limbs_buffer);
             rma.setMetadata(rma_mut.positive, rma_mut.len);
         }
     }
@@ -1869,11 +1910,14 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L
     assert(r.len >= x.len + y.len + 1);
 
     // 48 is a pretty abitrary size chosen based on performance of a factorial program.
-    if (x.len > 48) {
-        if (opt_allocator) |allocator| {
-            llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
-                error.OutOfMemory => {}, // handled below
-            };
+    k_mul: {
+        if (x.len > 48) {
+            if (opt_allocator) |allocator| {
+                llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
+                    error.OutOfMemory => break :k_mul, // handled below
+                };
+                return;
+            }
         }
     }
 
@@ -2203,6 +2247,42 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void {
     }
 }
 
+/// r MUST NOT alias x.
+fn llsquare_basecase(r: []Limb, x: []const Limb) void {
+    @setRuntimeSafety(debug_safety);
+
+    const x_norm = x;
+    assert(r.len >= 2 * x_norm.len + 1);
+
+    // Compute the square of a N-limb bigint with only (N^2 + N)/2
+    // multiplications by exploting the symmetry of the coefficients around the
+    // diagonal:
+    //
+    //           a   b   c *
+    //           a   b   c =
+    // -------------------
+    //          ca  cb  cc +
+    //      ba  bb  bc     +
+    //  aa  ab  ac
+    //
+    // Note that:
+    //  - Each mixed-product term appears twice for each column,
+    //  - Squares are always in the 2k (0 <= k < N) column
+
+    for (x_norm) |v, i| {
+        // Accumulate all the x[i]*x[j] (with x!=j) products
+        llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+    }
+
+    // Each product appears twice, multiply by 2
+    llshl(r, r[0 .. 2 * x_norm.len], 1);
+
+    for (x_norm) |v, i| {
+        // Compute and add the squares
+        llmulDigit(r[2 * i ..], x[i .. i + 1], v);
+    }
+}
+
 /// Knuth 4.6.3
 fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
     var tmp1: []Limb = undefined;
@@ -2212,9 +2292,9 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
     // variable, use the output limbs and another temporary set to overcome this
     // limitation.
     // The initial assignment makes the result end in `r` so an extra memory
-    // copy is saved, each 1 flips the index twice so it's a no-op so count the
-    // 0.
-    const b_leading_zeros = @intCast(u5, @clz(u32, b));
+    // copy is saved, each 1 flips the index twice so it's only the zeros that
+    // matter.
+    const b_leading_zeros = @clz(u32, b);
     const exp_zeros = @popCount(u32, ~b) - b_leading_zeros;
     if (exp_zeros & 1 != 0) {
         tmp1 = tmp_limbs;
@@ -2224,32 +2304,28 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
         tmp2 = tmp_limbs;
     }
 
-    const a_norm = a[0..llnormalize(a)];
-
-    mem.copy(Limb, tmp1, a_norm);
-    mem.set(Limb, tmp1[a_norm.len..], 0);
+    mem.copy(Limb, tmp1, a);
+    mem.set(Limb, tmp1[a.len..], 0);
 
     // Scan the exponent as a binary number, from left to right, dropping the
     // most significant bit set.
-    const exp_bits = @intCast(u5, 31 - b_leading_zeros);
-    var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros;
+    // Square the result if the current bit is zero, square and multiply by a if
+    // it is one.
+    var exp_bits = 32 - 1 - b_leading_zeros;
+    var exp = b << @intCast(u5, 1 + b_leading_zeros);
 
-    var i: u5 = 0;
+    var i: usize = 0;
     while (i < exp_bits) : (i += 1) {
         // Square
-        {
-            mem.set(Limb, tmp2, 0);
-            const op = tmp1[0..llnormalize(tmp1)];
-            llmulacc(null, tmp2, op, op);
-            mem.swap([]Limb, &tmp1, &tmp2);
-        }
+        mem.set(Limb, tmp2, 0);
+        llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]);
+        mem.swap([]Limb, &tmp1, &tmp2);
         // Multiply by a
-        if (exp & 1 != 0) {
+        if (@shlWithOverflow(u32, exp, 1, &exp)) {
             mem.set(Limb, tmp2, 0);
-            llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm);
+            llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a);
             mem.swap([]Limb, &tmp1, &tmp2);
         }
-        exp >>= 1;
     }
 }
 
lib/std/math/big/int_test.zig
@@ -720,6 +720,27 @@ test "big.int mul 0*0" {
     testing.expect((try c.to(u32)) == 0);
 }
 
+test "big.int mul large" {
+    var a = try Managed.initCapacity(testing.allocator, 50);
+    defer a.deinit();
+    var b = try Managed.initCapacity(testing.allocator, 100);
+    defer b.deinit();
+    var c = try Managed.initCapacity(testing.allocator, 100);
+    defer c.deinit();
+
+    // Generate a number that's large enough to cross the thresholds for the use
+    // of subquadratic algorithms
+    for (a.limbs) |*p| {
+        p.* = std.math.maxInt(Limb);
+    }
+    a.setMetadata(true, 50);
+
+    try b.mul(a.toConst(), a.toConst());
+    try c.sqr(a.toConst());
+
+    testing.expect(b.eq(c));
+}
+
 test "big.int div single-single no rem" {
     var a = try Managed.initSet(testing.allocator, 50);
     defer a.deinit();
@@ -1483,11 +1504,14 @@ test "big.int const to managed" {
 
 test "big.int pow" {
     {
-        var a = try Managed.initSet(testing.allocator, 10);
+        var a = try Managed.initSet(testing.allocator, -3);
         defer a.deinit();
 
-        try a.pow(a, 8);
-        testing.expectEqual(@as(u32, 100000000), try a.to(u32));
+        try a.pow(a.toConst(), 3);
+        testing.expectEqual(@as(i32, -27), try a.to(i32));
+
+        try a.pow(a.toConst(), 4);
+        testing.expectEqual(@as(i32, 531441), try a.to(i32));
     }
     {
         var a = try Managed.initSet(testing.allocator, 10);
@@ -1497,9 +1521,9 @@ test "big.int pow" {
         defer y.deinit();
 
         // y and a are not aliased
-        try y.pow(a, 123);
+        try y.pow(a.toConst(), 123);
         // y and a are aliased
-        try a.pow(a, 123);
+        try a.pow(a.toConst(), 123);
 
         testing.expect(a.eq(y));
 
@@ -1517,18 +1541,18 @@ test "big.int pow" {
         var a = try Managed.initSet(testing.allocator, 0);
         defer a.deinit();
 
-        try a.pow(a, 100);
+        try a.pow(a.toConst(), 100);
         testing.expectEqual(@as(i32, 0), try a.to(i32));
 
         try a.set(1);
-        try a.pow(a, 0);
+        try a.pow(a.toConst(), 0);
         testing.expectEqual(@as(i32, 1), try a.to(i32));
-        try a.pow(a, 100);
+        try a.pow(a.toConst(), 100);
         testing.expectEqual(@as(i32, 1), try a.to(i32));
         try a.set(-1);
-        try a.pow(a, 15);
+        try a.pow(a.toConst(), 15);
         testing.expectEqual(@as(i32, -1), try a.to(i32));
-        try a.pow(a, 16);
+        try a.pow(a.toConst(), 16);
         testing.expectEqual(@as(i32, 1), try a.to(i32));
     }
 }