Commit 41e9c1bac1
Changed files (1)
lib
std
math
big
lib/std/math/big/int.zig
@@ -658,7 +658,7 @@ pub const Mutable = struct {
mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0);
- llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
+ _ = llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
rma.normalize(a.limbs.len + b.limbs.len);
rma.positive = (a.positive == b.positive);
@@ -2365,9 +2365,12 @@ const AccOp = enum {
///
/// r = r (op) a * b
/// r MUST NOT alias any of a or b.
+///
+/// The result is truncated to `r.len` limbs, i.e. it is computed modulo the limb base raised to the power `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
@setRuntimeSafety(debug_safety);
- assert(r.len >= a.len + b.len);
+ assert(r.len >= a.len);
+ assert(r.len >= b.len);
// Order greatest first.
var x = a;
@@ -2395,6 +2398,8 @@ fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []cons
///
/// r = r (op) a * b
/// r MUST NOT alias any of a or b.
+///
+/// The result is truncated to `r.len` limbs, i.e. it is computed modulo the limb base raised to the power `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
fn llmulaccKaratsuba(
comptime op: AccOp,
allocator: *Allocator,
@@ -2403,7 +2408,7 @@ fn llmulaccKaratsuba(
b: []const Limb,
) error{OutOfMemory}!void {
@setRuntimeSafety(debug_safety);
- assert(r.len >= a.len + b.len);
+ assert(r.len >= a.len);
assert(a.len >= b.len);
// Classical karatsuba algorithm:
@@ -2437,49 +2442,84 @@ fn llmulaccKaratsuba(
//
// Note, when B is a multiple of the limb size, multiplies by B amount to shifts or
// slices of a limbs array.
+ //
+ // This function computes the result of the multiplication modulo r.len. This means:
+ // - p2 and p1 only need to be computed modulo r.len - B.
+ // - In the case of p2, p2 * B^2 needs to be added modulo r.len - 2 * B.
const split = b.len / 2; // B
+
+ const limbs_after_split = r.len - split; // Limbs to compute for p1 and p2.
+ const limbs_after_split2 = r.len - split * 2; // Limbs to add for p2 * B^2.
+
+ // For a0 and b0 we need the full range.
const a0 = a[0..llnormalize(a[0..split])];
- const a1 = a[split..][0..llnormalize(a[split..])];
const b0 = b[0..llnormalize(b[0..split])];
- const b1 = b[split..][0..llnormalize(b[split..])];
- // Note that the above slices work because we have a.len > b.len.
- // We now also have:
- // a1.len >= a0.len
- // a1.len >= b1.len >= b0.len
- // a0.len == b0.len
+ // For a1 and b1 we only need `limbs_after_split` limbs.
+ const a1 = blk: {
+ var a1 = a[split..];
+ a1.len = math.min(llnormalize(a1), limbs_after_split);
+ break :blk a1;
+ };
+
+ const b1 = blk: {
+ var b1 = b[split..];
+ b1.len = math.min(llnormalize(b1), limbs_after_split);
+ break :blk b1;
+ };
+
+ // Note that the above slices relative to `split` work because we have a.len >= b.len.
// We need some temporary memory to store intermediate results.
// Note, we can reduce the amount of temporaries we need by reordering the computation here:
// ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
// = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0
// = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B
- // By allocating a1.len * b1.len we can be sure that all the intermediary results fit.
+
+ // Allocate at least enough memory to be able to multiply the upper two segments of a and b, assuming
+ // no overflow.
const tmp = try allocator.alloc(Limb, a.len - split + b.len - split);
defer allocator.free(tmp);
// Compute p2.
- mem.set(Limb, tmp, 0);
- llmulacc(.add, allocator, tmp, a1, b1);
- const p2 = tmp[0 .. llnormalize(tmp)];
+ // Note, we don't need to compute all of p2, just enough limbs to satisfy r.
+ const p2_limbs = math.min(limbs_after_split, a1.len + b1.len);
- // Add terms p2 * B^2 and p2 * B to the result.
- _ = llaccum(op, r[split..], p2);
- _ = llaccum(op, r[split * 2..], p2);
+ mem.set(Limb, tmp[0..p2_limbs], 0);
+ llmulacc(.add, allocator, tmp[0..p2_limbs], a1[0..math.min(a1.len, p2_limbs)], b1[0..math.min(b1.len, p2_limbs)]);
+ const p2 = tmp[0 .. llnormalize(tmp[0..p2_limbs])];
+
+ // Add p2 * B to the result.
+ llaccum(op, r[split..], p2);
+
+ // Add p2 * B^2 to the result if required.
+ if (limbs_after_split2 > 0) {
+ llaccum(op, r[split * 2..], p2[0..math.min(p2.len, limbs_after_split2)]);
+ }
// Compute p0.
- mem.set(Limb, p2, 0);
- llmulacc(.add, allocator, tmp, a0, b0);
- const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])];
+ // Since a0.len, b0.len <= split and r.len >= split * 2, the full width of p0 needs to be computed.
+ const p0_limbs = a0.len + b0.len;
+ mem.set(Limb, tmp[0..p0_limbs], 0);
+ llmulacc(.add, allocator, tmp[0..p0_limbs], a0, b0);
+ const p0 = tmp[0 .. llnormalize(tmp[0..p0_limbs])];
+
+ // Add p0 to the result.
+ llaccum(op, r, p0);
+
+ // Add p0 * B to the result. In this case, we may not need all of it.
+ llaccum(op, r[split..], p0[0..math.min(limbs_after_split, p0.len)]);
- // Add terms p0 * B and p0 to the result.
- _ = llaccum(op, r, p0);
- _ = llaccum(op, r[split..], p0);
// Finally, compute and add p1.
- const j0_sign = llcmp(a0, a1);
- const j1_sign = llcmp(b1, b0);
+ // From now on we only need `limbs_after_split` limbs of a0 and b0, since the result of the
+ // following computation is multiplied by B before it is added to r.
+ const a0x = a0[0..std.math.min(a0.len, limbs_after_split)];
+ const b0x = b0[0..std.math.min(b0.len, limbs_after_split)];
+
+ const j0_sign = llcmp(a0x, a1);
+ const j1_sign = llcmp(b1, b0x);
if (j0_sign * j1_sign == 0) {
// p1 is zero, we don't need to do any computation at all.
@@ -2492,24 +2532,24 @@ fn llmulaccKaratsuba(
// Note that in this case, we again need some storage for intermediary results
// j0 and j1. Since we have tmp.len >= 2B, we can store both
// intermediaries in the already allocated array.
- const j0 = tmp[0..a1.len];
- const j1 = tmp[a1.len..];
+ const j0 = tmp[0..a.len - split];
+ const j1 = tmp[a.len - split..];
// Ensure that no subtraction overflows.
if (j0_sign == 1) {
// a0 > a1.
- _ = llsubcarry(j0, a0, a1);
+ _ = llsubcarry(j0, a0x, a1);
} else {
// a0 < a1.
- _ = llsubcarry(j0, a1, a0);
+ _ = llsubcarry(j0, a1, a0x);
}
if (j1_sign == 1) {
// b1 > b0.
- _ = llsubcarry(j1, b1, b0);
+ _ = llsubcarry(j1, b1, b0x);
} else {
// b1 < b0.
- _ = llsubcarry(j1, b0, b1);
+ _ = llsubcarry(j1, b0x, b1);
}
if (j0_sign * j1_sign == 1) {
@@ -2528,11 +2568,13 @@ fn llmulaccKaratsuba(
}
}
-// r = r (op) a
-fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
+/// r = r (op) a.
+/// The result is truncated to `r.len` limbs; any carry out of the last limb is discarded.
+fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) void {
@setRuntimeSafety(debug_safety);
if (op == .sub) {
- return llsubcarry(r, r, a);
+ _ = llsubcarry(r, r, a);
+ return;
}
assert(r.len != 0 and a.len != 0);
@@ -2551,8 +2593,6 @@ fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
while ((carry != 0) and i < r.len) : (i += 1) {
carry = @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i]));
}
-
- return carry;
}
/// Returns -1, 0, 1 if |a| < |b|, |a| == |b| or |a| > |b| respectively for limbs.
@@ -2583,19 +2623,21 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
}
}
-// r = r (op) y * xi
+/// r = r (op) a * b
+/// The result is truncated to `r.len` limbs, i.e. it is computed modulo the limb base raised to the power `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void {
@setRuntimeSafety(debug_safety);
assert(r.len >= a.len + b.len);
assert(a.len >= b.len);
var i: usize = 0;
- while (i < a.len) : (i += 1) {
- llmulLimb(op, r[i..], b, a[i]);
+ while (i < b.len) : (i += 1) {
+ llmulLimb(op, r[i..], a, b[i]);
}
}
-// r = r (op) y * xi
+/// acc = acc (op) y * xi
+/// The result is truncated to `acc.len` limbs, i.e. it is computed modulo the limb base raised to the power `acc.len`.
fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
@setRuntimeSafety(debug_safety);
if (xi == 0) {