Commit a31b70c4b8
Changed files (2)
lib
std
math
lib/std/math/big/int.zig
@@ -446,6 +446,26 @@ pub const Mutable = struct {
rma.positive = (a.positive == b.positive);
}
+ /// rma = a * a
+ ///
+ /// `rma` may not alias with `a`.
+ ///
+ /// Asserts the result fits in `rma`. An upper bound on the number of limbs needed by
+ /// rma is given by `2 * a.limbs.len + 1`.
+ ///
+ /// If `allocator` is provided, it will be used for temporary storage to improve
+ /// multiplication performance. `error.OutOfMemory` is handled with a fallback algorithm.
+ pub fn sqrNoAlias(rma: *Mutable, a: Const, opt_allocator: ?*Allocator) void {
+ assert(rma.limbs.ptr != a.limbs.ptr); // illegal aliasing
+
+ mem.set(Limb, rma.limbs, 0);
+
+ llsquare_basecase(rma.limbs, a.limbs);
+
+ rma.normalize(2 * a.limbs.len + 1);
+ rma.positive = true;
+ }
+
/// q = a / b (rem r)
///
/// a / b are floored (rounded towards 0).
@@ -1827,7 +1847,28 @@ pub const Managed = struct {
rma.setMetadata(m.positive, m.len);
}
- pub fn pow(rma: *Managed, a: Managed, b: u32) !void {
+ /// r = a * a
+ pub fn sqr(rma: *Managed, a: Const) !void {
+ const needed_limbs = 2 * a.limbs.len + 1;
+
+ if (rma.limbs.ptr == a.limbs.ptr) {
+ var m = try Managed.initCapacity(rma.allocator, needed_limbs);
+ errdefer m.deinit();
+ var m_mut = m.toMutable();
+ m_mut.sqrNoAlias(a, rma.allocator);
+ m.setMetadata(m_mut.positive, m_mut.len);
+
+ rma.deinit();
+ rma.swap(&m);
+ } else {
+ try rma.ensureCapacity(needed_limbs);
+ var rma_mut = rma.toMutable();
+ rma_mut.sqrNoAlias(a, rma.allocator);
+ rma.setMetadata(rma_mut.positive, rma_mut.len);
+ }
+ }
+
+ pub fn pow(rma: *Managed, a: Const, b: u32) !void {
const needed_limbs = calcPowLimbsBufferLen(a.bitCountAbs(), b);
const limbs_buffer = try rma.allocator.alloc(Limb, needed_limbs);
@@ -1837,7 +1878,7 @@ pub const Managed = struct {
var m = try Managed.initCapacity(rma.allocator, needed_limbs);
errdefer m.deinit();
var m_mut = m.toMutable();
- try m_mut.pow(a.toConst(), b, limbs_buffer);
+ try m_mut.pow(a, b, limbs_buffer);
m.setMetadata(m_mut.positive, m_mut.len);
rma.deinit();
@@ -1845,7 +1886,7 @@ pub const Managed = struct {
} else {
try rma.ensureCapacity(needed_limbs);
var rma_mut = rma.toMutable();
- try rma_mut.pow(a.toConst(), b, limbs_buffer);
+ try rma_mut.pow(a, b, limbs_buffer);
rma.setMetadata(rma_mut.positive, rma_mut.len);
}
}
@@ -1869,11 +1910,14 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L
assert(r.len >= x.len + y.len + 1);
// 48 is a pretty abitrary size chosen based on performance of a factorial program.
- if (x.len > 48) {
- if (opt_allocator) |allocator| {
- llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
- error.OutOfMemory => {}, // handled below
- };
+ k_mul: {
+ if (x.len > 48) {
+ if (opt_allocator) |allocator| {
+ llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
+ error.OutOfMemory => break :k_mul, // handled below
+ };
+ return;
+ }
}
}
@@ -2203,6 +2247,42 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void {
}
}
+/// r MUST NOT alias x.
+fn llsquare_basecase(r: []Limb, x: []const Limb) void {
+ @setRuntimeSafety(debug_safety);
+
+ const x_norm = x;
+ assert(r.len >= 2 * x_norm.len + 1);
+
+ // Compute the square of a N-limb bigint with only (N^2 + N)/2
+ // multiplications by exploting the symmetry of the coefficients around the
+ // diagonal:
+ //
+ // a b c *
+ // a b c =
+ // -------------------
+ // ca cb cc +
+ // ba bb bc +
+ // aa ab ac
+ //
+ // Note that:
+ // - Each mixed-product term appears twice for each column,
+ // - Squares are always in the 2k (0 <= k < N) column
+
+ for (x_norm) |v, i| {
+ // Accumulate all the x[i]*x[j] (with x!=j) products
+ llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+ }
+
+ // Each product appears twice, multiply by 2
+ llshl(r, r[0 .. 2 * x_norm.len], 1);
+
+ for (x_norm) |v, i| {
+ // Compute and add the squares
+ llmulDigit(r[2 * i ..], x[i .. i + 1], v);
+ }
+}
+
/// Knuth 4.6.3
fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
var tmp1: []Limb = undefined;
@@ -2212,9 +2292,9 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
// variable, use the output limbs and another temporary set to overcome this
// limitation.
// The initial assignment makes the result end in `r` so an extra memory
- // copy is saved, each 1 flips the index twice so it's a no-op so count the
- // 0.
- const b_leading_zeros = @intCast(u5, @clz(u32, b));
+ // copy is saved, each 1 flips the index twice so it's only the zeros that
+ // matter.
+ const b_leading_zeros = @clz(u32, b);
const exp_zeros = @popCount(u32, ~b) - b_leading_zeros;
if (exp_zeros & 1 != 0) {
tmp1 = tmp_limbs;
@@ -2224,32 +2304,28 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
tmp2 = tmp_limbs;
}
- const a_norm = a[0..llnormalize(a)];
-
- mem.copy(Limb, tmp1, a_norm);
- mem.set(Limb, tmp1[a_norm.len..], 0);
+ mem.copy(Limb, tmp1, a);
+ mem.set(Limb, tmp1[a.len..], 0);
// Scan the exponent as a binary number, from left to right, dropping the
// most significant bit set.
- const exp_bits = @intCast(u5, 31 - b_leading_zeros);
- var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros;
+ // Square the result if the current bit is zero, square and multiply by a if
+ // it is one.
+ var exp_bits = 32 - 1 - b_leading_zeros;
+ var exp = b << @intCast(u5, 1 + b_leading_zeros);
- var i: u5 = 0;
+ var i: usize = 0;
while (i < exp_bits) : (i += 1) {
// Square
- {
- mem.set(Limb, tmp2, 0);
- const op = tmp1[0..llnormalize(tmp1)];
- llmulacc(null, tmp2, op, op);
- mem.swap([]Limb, &tmp1, &tmp2);
- }
+ mem.set(Limb, tmp2, 0);
+ llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]);
+ mem.swap([]Limb, &tmp1, &tmp2);
// Multiply by a
- if (exp & 1 != 0) {
+ if (@shlWithOverflow(u32, exp, 1, &exp)) {
mem.set(Limb, tmp2, 0);
- llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm);
+ llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a);
mem.swap([]Limb, &tmp1, &tmp2);
}
- exp >>= 1;
}
}
lib/std/math/big/int_test.zig
@@ -720,6 +720,27 @@ test "big.int mul 0*0" {
testing.expect((try c.to(u32)) == 0);
}
+test "big.int mul large" {
+ var a = try Managed.initCapacity(testing.allocator, 50);
+ defer a.deinit();
+ var b = try Managed.initCapacity(testing.allocator, 100);
+ defer b.deinit();
+ var c = try Managed.initCapacity(testing.allocator, 100);
+ defer c.deinit();
+
+ // Generate a number that's large enough to cross the thresholds for the use
+ // of subquadratic algorithms
+ for (a.limbs) |*p| {
+ p.* = std.math.maxInt(Limb);
+ }
+ a.setMetadata(true, 50);
+
+ try b.mul(a.toConst(), a.toConst());
+ try c.sqr(a.toConst());
+
+ testing.expect(b.eq(c));
+}
+
test "big.int div single-single no rem" {
var a = try Managed.initSet(testing.allocator, 50);
defer a.deinit();
@@ -1483,11 +1504,14 @@ test "big.int const to managed" {
test "big.int pow" {
{
- var a = try Managed.initSet(testing.allocator, 10);
+ var a = try Managed.initSet(testing.allocator, -3);
defer a.deinit();
- try a.pow(a, 8);
- testing.expectEqual(@as(u32, 100000000), try a.to(u32));
+ try a.pow(a.toConst(), 3);
+ testing.expectEqual(@as(i32, -27), try a.to(i32));
+
+ try a.pow(a.toConst(), 4);
+ testing.expectEqual(@as(i32, 531441), try a.to(i32));
}
{
var a = try Managed.initSet(testing.allocator, 10);
@@ -1497,9 +1521,9 @@ test "big.int pow" {
defer y.deinit();
// y and a are not aliased
- try y.pow(a, 123);
+ try y.pow(a.toConst(), 123);
// y and a are aliased
- try a.pow(a, 123);
+ try a.pow(a.toConst(), 123);
testing.expect(a.eq(y));
@@ -1517,18 +1541,18 @@ test "big.int pow" {
var a = try Managed.initSet(testing.allocator, 0);
defer a.deinit();
- try a.pow(a, 100);
+ try a.pow(a.toConst(), 100);
testing.expectEqual(@as(i32, 0), try a.to(i32));
try a.set(1);
- try a.pow(a, 0);
+ try a.pow(a.toConst(), 0);
testing.expectEqual(@as(i32, 1), try a.to(i32));
- try a.pow(a, 100);
+ try a.pow(a.toConst(), 100);
testing.expectEqual(@as(i32, 1), try a.to(i32));
try a.set(-1);
- try a.pow(a, 15);
+ try a.pow(a.toConst(), 15);
testing.expectEqual(@as(i32, -1), try a.to(i32));
- try a.pow(a, 16);
+ try a.pow(a.toConst(), 16);
testing.expectEqual(@as(i32, 1), try a.to(i32));
}
}