Commit 5907b3e383
Changed files (1)
lib
std
math
big
lib/std/math/big/int.zig
@@ -86,6 +86,24 @@ pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
return r1;
}
+/// Returns the low limb of a - b * c - carry.*, and sets carry.* to the borrow for the next limb.
+fn subMulLimbWithBorrow(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
+ // r1 = a - carry.* (c1 is 1 if the subtraction wrapped around)
+ var r1: Limb = undefined;
+ const c1: Limb = @boolToInt(@subWithOverflow(Limb, a, carry.*, &r1));
+
+ // r2 = low limb of b * c, c2 = the bits of b * c above limb_bits
+ const bc = @as(DoubleLimb, std.math.mulWide(Limb, b, c));
+ const r2 = @truncate(Limb, bc);
+ const c2 = @truncate(Limb, bc >> limb_bits);
+
+ // r1 = r1 - r2 (c3 is 1 if the subtraction wrapped around)
+ const c3: Limb = @boolToInt(@subWithOverflow(Limb, r1, r2, &r1));
+ carry.* = c1 + c2 + c3;
+
+ return r1;
+}
+
/// Used to indicate either limit of a 2s-complement integer.
pub const TwosCompIntLimit = enum {
// The low limit, either 0x00 (unsigned) or (-)0x80 (signed) for an 8-bit integer.
@@ -640,7 +658,7 @@ pub const Mutable = struct {
mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0);
- llmulacc(allocator, rma.limbs, a.limbs, b.limbs);
+ llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
rma.normalize(a.limbs.len + b.limbs.len);
rma.positive = (a.positive == b.positive);
@@ -665,9 +683,9 @@ pub const Mutable = struct {
mem.set(Limb, rma.limbs[0..req_limbs], 0);
if (a_limbs.len >= b_limbs.len) {
- llmulacc_lo(rma.limbs, a_limbs, b_limbs);
+ llmulaccLow(rma.limbs, a_limbs, b_limbs);
} else {
- llmulacc_lo(rma.limbs, b_limbs, a_limbs);
+ llmulaccLow(rma.limbs, b_limbs, a_limbs);
}
rma.normalize(math.min(req_limbs, a.limbs.len + b.limbs.len));
@@ -691,7 +709,7 @@ pub const Mutable = struct {
mem.set(Limb, rma.limbs, 0);
- llsquare_basecase(rma.limbs, a.limbs);
+ llsquareBasecase(rma.limbs, a.limbs);
rma.normalize(2 * a.limbs.len + 1);
rma.positive = true;
@@ -1219,7 +1237,7 @@ pub const Mutable = struct {
/// Asserts `r` has enough storage to store the result.
/// The upper bound is `calcTwosCompLimbCount(a.len)`.
pub fn truncate(r: *Mutable, a: Const, signedness: std.builtin.Signedness, bit_count: usize) void {
- const req_limbs = (bit_count + @bitSizeOf(Limb) - 1) / @bitSizeOf(Limb);
+ const req_limbs = calcTwosCompLimbCount(bit_count);
// Handle 0-bit integers.
if (req_limbs == 0 or a.eqZero()) {
@@ -2319,8 +2337,8 @@ pub const Managed = struct {
}
};
-/// r = a * b, ignoring overflow
-fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void {
+/// r = r + a * b, ignoring overflow
+fn llmulaccLow(r: []Limb, a: []const Limb, b: []const Limb) void {
assert(r.len >= a.len);
assert(a.len >= b.len);
@@ -2328,32 +2346,41 @@ fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void {
var i: usize = 0;
while (i < b.len) : (i += 1) {
- llmulDigit(r[i..], a, b[i]);
+ llmulLimb(.add, r[i..], a, b[i]);
}
}
+/// Different operators which can be used in accumulation style functions
+/// (llmulacc, llmulaccKaratsuba, llmulaccLong, llmulLimb, llaccum). In all these functions,
+/// a computed value is accumulated with an existing result.
+const AccOp = enum {
+ /// The computed value is added to the result.
+ add,
+
+ /// The computed value is subtracted from the result.
+ sub,
+};
+
/// Knuth 4.3.1, Algorithm M.
///
+/// r = r (op) a * b
/// r MUST NOT alias any of a or b.
-fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
+fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
@setRuntimeSafety(debug_safety);
+ assert(r.len >= a.len + b.len);
- const a_norm = a[0..llnormalize(a)];
- const b_norm = b[0..llnormalize(b)];
- var x = a_norm;
- var y = b_norm;
- if (a_norm.len > b_norm.len) {
- x = b_norm;
- y = a_norm;
+ // Order greatest first.
+ var x = a;
+ var y = b;
+ if (a.len < b.len) {
+ x = b;
+ y = a;
}
- assert(r.len >= x.len + y.len + 1);
-
- // 48 is a pretty abitrary size chosen based on performance of a factorial program.
k_mul: {
- if (x.len > 48) {
+ if (y.len > 48) {
if (opt_allocator) |allocator| {
- llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
+ llmulaccKaratsuba(op, allocator, r, x, y) catch |err| switch (err) {
error.OutOfMemory => break :k_mul, // handled below
};
return;
@@ -2361,83 +2388,153 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L
}
}
- // Basecase multiplication
- var i: usize = 0;
- while (i < x.len) : (i += 1) {
- llmulDigit(r[i..], y, x[i]);
- }
+ llmulaccLong(op, r, x, y);
}
/// Knuth 4.3.1, Algorithm M.
///
+/// r = r (op) a * b
/// r MUST NOT alias any of a or b.
-fn llmulacc_karatsuba(allocator: *Allocator, r: []Limb, x: []const Limb, y: []const Limb) error{OutOfMemory}!void {
+fn llmulaccKaratsuba(
+ comptime op: AccOp,
+ allocator: *Allocator,
+ r: []Limb,
+ a: []const Limb,
+ b: []const Limb,
+) error{OutOfMemory}!void {
@setRuntimeSafety(debug_safety);
+ assert(r.len >= a.len + b.len);
+ assert(a.len >= b.len);
- assert(r.len >= x.len + y.len + 1);
-
- const split = @divFloor(x.len, 2);
- var x0 = x[0..split];
- var x1 = x[split..x.len];
- var y0 = y[0..split];
- var y1 = y[split..y.len];
-
- var tmp = try allocator.alloc(Limb, x1.len + y1.len + 1);
+ // Classical karatsuba algorithm:
+ // a = a1 * B + a0
+ // b = b1 * B + b0
+ // Where a0, b0 < B
+ //
+ // We then have:
+ // ab = a * b
+ // = (a1 * B + a0) * (b1 * B + b0)
+ // = a1 * b1 * B * B + a1 * B * b0 + a0 * b1 * B + a0 * b0
+ // = a1 * b1 * B * B + (a1 * b0 + a0 * b1) * B + a0 * b0
+ //
+ // Note that:
+ // a1 * b0 + a0 * b1
+ // = (a1 + a0)(b1 + b0) - a1 * b1 - a0 * b0
+ // = (a0 - a1)(b1 - b0) + a1 * b1 + a0 * b0
+ //
+ // This yields:
+ // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
+ //
+ // Where:
+ // p0 = a0 * b0
+ // p1 = (a0 - a1)(b1 - b0)
+ // p2 = a1 * b1
+ //
+ // Note, (a0 - a1) and (b1 - b0) produce values -B < x < B, and so we need to mind the sign here.
+ // We also have:
+ // 0 <= p0 <= 2B
+ // -2B <= p1 <= 2B
+ //
+ // Note, when B is a multiple of the limb size, multiplies by B amount to shifts or
+ // slices of a limbs array.
+
+ const split = b.len / 2; // B
+ const a0 = a[0..llnormalize(a[0..split])];
+ const a1 = a[split..][0..llnormalize(a[split..])];
+ const b0 = b[0..llnormalize(b[0..split])];
+ const b1 = b[split..][0..llnormalize(b[split..])];
+
+ // Note that the above slices work because we have a.len >= b.len.
+ // We now also have:
+ // a1.len >= a0.len
+ // a1.len >= b1.len >= b0.len
+ // a0.len == b0.len
+
+ // We need some temporary memory to store intermediate results.
+ // Note, we can reduce the amount of temporaries we need by reordering the computation here:
+ // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
+ // = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0
+ // = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B
+ // By allocating (a.len - split) + (b.len - split) limbs, an upper bound on a1.len + b1.len, we can be sure that all the intermediary results fit.
+ const tmp = try allocator.alloc(Limb, a.len - split + b.len - split);
defer allocator.free(tmp);
- mem.set(Limb, tmp, 0);
- llmulacc(allocator, tmp, x1, y1);
+ // Compute p2.
+ mem.set(Limb, tmp, 0);
+ llmulacc(.add, allocator, tmp, a1, b1);
+ const p2 = tmp[0 .. llnormalize(tmp)];
- var length = llnormalize(tmp);
- _ = llaccum(r[split..], tmp[0..length]);
- _ = llaccum(r[split * 2 ..], tmp[0..length]);
+ // Add terms p2 * B^2 and p2 * B to the result.
+ _ = llaccum(op, r[split..], p2);
+ _ = llaccum(op, r[split * 2..], p2);
- mem.set(Limb, tmp[0..length], 0);
+ // Compute p0.
+ mem.set(Limb, p2, 0);
+ llmulacc(.add, allocator, tmp, a0, b0);
+ const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])];
- llmulacc(allocator, tmp, x0, y0);
+ // Add terms p0 * B and p0 to the result.
+ _ = llaccum(op, r, p0);
+ _ = llaccum(op, r[split..], p0);
- length = llnormalize(tmp);
- _ = llaccum(r[0..], tmp[0..length]);
- _ = llaccum(r[split..], tmp[0..length]);
+ // Finally, compute and add p1.
+ const j0_sign = llcmp(a0, a1);
+ const j1_sign = llcmp(b1, b0);
- const x_cmp = llcmp(x1, x0);
- const y_cmp = llcmp(y1, y0);
- if (x_cmp * y_cmp == 0) {
+ if (j0_sign * j1_sign == 0) {
+ // p1 is zero, we don't need to do any computation at all.
return;
}
- const x0_len = llnormalize(x0);
- const x1_len = llnormalize(x1);
- var j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len));
- defer allocator.free(j0);
- if (x_cmp == 1) {
- llsub(j0, x1[0..x1_len], x0[0..x0_len]);
+
+ mem.set(Limb, tmp, 0);
+
+ // p1 is nonzero, so compute the intermediary terms j0 = a0 - a1 and j1 = b1 - b0.
+ // Note that in this case, we again need some storage for intermediary results
+ // j0 and j1. Since we have tmp.len >= 2B, we can store both
+ // intermediaries in the already allocated array.
+ const j0 = tmp[0..a1.len];
+ const j1 = tmp[a1.len..];
+
+ // Ensure that no subtraction overflows.
+ if (j0_sign == 1) {
+ // a0 > a1.
+ _ = llsubcarry(j0, a0, a1);
} else {
- llsub(j0, x0[0..x0_len], x1[0..x1_len]);
+ // a0 < a1.
+ _ = llsubcarry(j0, a1, a0);
}
- const y0_len = llnormalize(y0);
- const y1_len = llnormalize(y1);
- var j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len));
- defer allocator.free(j1);
- if (y_cmp == 1) {
- llsub(j1, y1[0..y1_len], y0[0..y0_len]);
+ if (j1_sign == 1) {
+ // b1 > b0.
+ _ = llsubcarry(j1, b1, b0);
} else {
- llsub(j1, y0[0..y0_len], y1[0..y1_len]);
+ // b1 < b0.
+ _ = llsubcarry(j1, b0, b1);
}
- if (x_cmp == y_cmp) {
- mem.set(Limb, tmp[0..length], 0);
- llmulacc(allocator, tmp, j0, j1);
- length = llnormalize(tmp);
- llsub(r[split..], r[split..], tmp[0..length]);
+ if (j0_sign * j1_sign == 1) {
+ // If j0 and j1 are both positive, we now have:
+ // p1 = j0 * j1
+ // If j0 and j1 are both negative, we now have:
+ // p1 = -j0 * -j1 = j0 * j1
+ // In this case we can add p1 to the result using llmulacc.
+ llmulacc(op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]);
} else {
- llmulacc(allocator, r[split..], j0, j1);
+ // In this case either j0 or j1 is negative, and we have:
+ // p1 = -(j0 * j1)
+ // Now we need to subtract instead of accumulate.
+ const inverted_op = if (op == .add) .sub else .add;
+ llmulacc(inverted_op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]);
}
}
-// r = r + a
-fn llaccum(r: []Limb, a: []const Limb) Limb {
+// r = r (op) a
+fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
@setRuntimeSafety(debug_safety);
+ if (op == .sub) {
+ return llsubcarry(r, r, a);
+ }
+
assert(r.len != 0 and a.len != 0);
assert(r.len >= a.len);
@@ -2486,24 +2583,53 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
}
}
-fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void {
+// r = r (op) a * b, computed by long multiplication: one llmulLimb pass over b per limb of a.
+fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void {
+ @setRuntimeSafety(debug_safety);
+ assert(r.len >= a.len + b.len);
+ assert(a.len >= b.len);
+
+ var i: usize = 0;
+ while (i < a.len) : (i += 1) {
+ llmulLimb(op, r[i..], b, a[i]);
+ }
+}
+
+// acc = acc (op) y * xi
+fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
@setRuntimeSafety(debug_safety);
if (xi == 0) {
return;
}
- var carry: Limb = 0;
var a_lo = acc[0..y.len];
var a_hi = acc[y.len..];
- var j: usize = 0;
- while (j < a_lo.len) : (j += 1) {
- a_lo[j] = @call(.{ .modifier = .always_inline }, addMulLimbWithCarry, .{ a_lo[j], y[j], xi, &carry });
- }
+ switch (op) {
+ .add => {
+ var carry: Limb = 0;
+ var j: usize = 0;
+ while (j < a_lo.len) : (j += 1) {
+ a_lo[j] = addMulLimbWithCarry(a_lo[j], y[j], xi, &carry);
+ }
- j = 0;
- while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
- carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
+ j = 0;
+ while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
+ carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
+ }
+ },
+ .sub => {
+ var borrow: Limb = 0;
+ var j: usize = 0;
+ while (j < a_lo.len) : (j += 1) {
+ a_lo[j] = subMulLimbWithBorrow(a_lo[j], y[j], xi, &borrow);
+ }
+
+ j = 0;
+ while ((borrow != 0) and (j < a_hi.len)) : (j += 1) {
+ borrow = @boolToInt(@subWithOverflow(Limb, a_hi[j], borrow, &a_hi[j]));
+ }
+ },
}
}
@@ -2964,7 +3090,7 @@ fn llsignedxor(r: []Limb, a: []const Limb, a_positive: bool, b: []const Limb, b_
}
/// r MUST NOT alias x.
-fn llsquare_basecase(r: []Limb, x: []const Limb) void {
+fn llsquareBasecase(r: []Limb, x: []const Limb) void {
@setRuntimeSafety(debug_safety);
const x_norm = x;
@@ -2987,7 +3113,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void {
for (x_norm) |v, i| {
// Accumulate all the x[i]*x[j] (with x!=j) products
- llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+ llmulLimb(.add, r[2 * i + 1 ..], x_norm[i + 1 ..], v);
}
// Each product appears twice, multiply by 2
@@ -2995,7 +3121,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void {
for (x_norm) |v, i| {
// Compute and add the squares
- llmulDigit(r[2 * i ..], x[i .. i + 1], v);
+ llmulLimb(.add, r[2 * i ..], x[i .. i + 1], v);
}
}
@@ -3034,12 +3160,12 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
while (i < exp_bits) : (i += 1) {
// Square
mem.set(Limb, tmp2, 0);
- llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]);
+ llsquareBasecase(tmp2, tmp1[0..llnormalize(tmp1)]);
mem.swap([]Limb, &tmp1, &tmp2);
// Multiply by a
if (@shlWithOverflow(u32, exp, 1, &exp)) {
mem.set(Limb, tmp2, 0);
- llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a);
+ llmulacc(.add, null, tmp2, tmp1[0..llnormalize(tmp1)], a);
mem.swap([]Limb, &tmp1, &tmp2);
}
}