Commit 5907b3e383

Robin Voetter <robin@voetter.nl>
2021-10-03 01:20:10
big ints: Improve karatsuba multiplication
1 parent 15351f2
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -86,6 +86,24 @@ pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
     return r1;
 }
 
+/// Returns the low limb of a - b * c - carry.*, and stores the resulting borrow back into carry.*.
+fn subMulLimbWithBorrow(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
+    // r1 = a - carry.* (wrapped); c1 is 1 if this subtraction borrowed.
+    var r1: Limb = undefined;
+    const c1: Limb = @boolToInt(@subWithOverflow(Limb, a, carry.*, &r1));
+
+    // bc = b * c as a double limb; r2 is its low limb and c2 its high limb.
+    const bc = @as(DoubleLimb, std.math.mulWide(Limb, b, c));
+    const r2 = @truncate(Limb, bc);
+    const c2 = @truncate(Limb, bc >> limb_bits);
+
+    // r1 = r1 - r2 (wrapped); c3 is 1 if this subtraction borrowed.
+    const c3: Limb = @boolToInt(@subWithOverflow(Limb, r1, r2, &r1));
+    carry.* = c1 + c2 + c3; // outgoing borrow: both subtraction borrows plus the high limb of b * c
+
+    return r1;
+}
+
 /// Used to indicate either limit of a 2s-complement integer.
 pub const TwosCompIntLimit = enum {
     // The low limit, either 0x00 (unsigned) or (-)0x80 (signed) for an 8-bit integer.
@@ -640,7 +658,7 @@ pub const Mutable = struct {
 
         mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0);
 
-        llmulacc(allocator, rma.limbs, a.limbs, b.limbs);
+        llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
 
         rma.normalize(a.limbs.len + b.limbs.len);
         rma.positive = (a.positive == b.positive);
@@ -665,9 +683,9 @@ pub const Mutable = struct {
         mem.set(Limb, rma.limbs[0..req_limbs], 0);
 
         if (a_limbs.len >= b_limbs.len) {
-            llmulacc_lo(rma.limbs, a_limbs, b_limbs);
+            llmulaccLow(rma.limbs, a_limbs, b_limbs);
         } else {
-            llmulacc_lo(rma.limbs, b_limbs, a_limbs);
+            llmulaccLow(rma.limbs, b_limbs, a_limbs);
         }
 
         rma.normalize(math.min(req_limbs, a.limbs.len + b.limbs.len));
@@ -691,7 +709,7 @@ pub const Mutable = struct {
 
         mem.set(Limb, rma.limbs, 0);
 
-        llsquare_basecase(rma.limbs, a.limbs);
+        llsquareBasecase(rma.limbs, a.limbs);
 
         rma.normalize(2 * a.limbs.len + 1);
         rma.positive = true;
@@ -1219,7 +1237,7 @@ pub const Mutable = struct {
     /// Asserts `r` has enough storage to store the result.
     /// The upper bound is `calcTwosCompLimbCount(a.len)`.
     pub fn truncate(r: *Mutable, a: Const, signedness: std.builtin.Signedness, bit_count: usize) void {
-        const req_limbs = (bit_count + @bitSizeOf(Limb) - 1) / @bitSizeOf(Limb);
+        const req_limbs = calcTwosCompLimbCount(bit_count);
 
         // Handle 0-bit integers.
         if (req_limbs == 0 or a.eqZero()) {
@@ -2319,8 +2337,8 @@ pub const Managed = struct {
     }
 };
 
-/// r = a * b, ignoring overflow
-fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void {
+/// r = r + a * b, ignoring overflow
+fn llmulaccLow(r: []Limb, a: []const Limb, b: []const Limb) void {
     assert(r.len >= a.len);
     assert(a.len >= b.len);
 
@@ -2328,32 +2346,41 @@ fn llmulacc_lo(r: []Limb, a: []const Limb, b: []const Limb) void {
 
     var i: usize = 0;
     while (i < b.len) : (i += 1) {
-        llmulDigit(r[i..], a, b[i]);
+        llmulLimb(.add, r[i..], a, b[i]);
     }
 }
 
+/// Different operators which can be used in accumulation style functions
+/// (llmulacc, llmulaccKaratsuba, llmulaccLong, llmulLimb, llaccum). In all these functions,
+/// a computed value is accumulated into an existing result.
+const AccOp = enum {
+    /// The computed value is added to the result.
+    add,
+
+    /// The computed value is subtracted from the result.
+    sub,
+};
+
 /// Knuth 4.3.1, Algorithm M.
 ///
+/// r = r (op) a * b
 /// r MUST NOT alias any of a or b.
-fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
+fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
     @setRuntimeSafety(debug_safety);
+    assert(r.len >= a.len + b.len);
 
-    const a_norm = a[0..llnormalize(a)];
-    const b_norm = b[0..llnormalize(b)];
-    var x = a_norm;
-    var y = b_norm;
-    if (a_norm.len > b_norm.len) {
-        x = b_norm;
-        y = a_norm;
+    // Order greatest first.
+    var x = a;
+    var y = b;
+    if (a.len < b.len) {
+        x = b;
+        y = a;
     }
 
-    assert(r.len >= x.len + y.len + 1);
-
-    // 48 is a pretty abitrary size chosen based on performance of a factorial program.
     k_mul: {
-        if (x.len > 48) {
+        if (y.len > 48) {
             if (opt_allocator) |allocator| {
-                llmulacc_karatsuba(allocator, r, x, y) catch |err| switch (err) {
+                llmulaccKaratsuba(op, allocator, r, x, y) catch |err| switch (err) {
                     error.OutOfMemory => break :k_mul, // handled below
                 };
                 return;
@@ -2361,83 +2388,153 @@ fn llmulacc(opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const L
         }
     }
 
-    // Basecase multiplication
-    var i: usize = 0;
-    while (i < x.len) : (i += 1) {
-        llmulDigit(r[i..], y, x[i]);
-    }
+    llmulaccLong(op, r, x, y);
 }
 
 /// Karatsuba multiplication (see Knuth 4.3.3).
 ///
+/// r = r (op) a * b
 /// r MUST NOT alias any of a or b.
-fn llmulacc_karatsuba(allocator: *Allocator, r: []Limb, x: []const Limb, y: []const Limb) error{OutOfMemory}!void {
+fn llmulaccKaratsuba(
+    comptime op: AccOp,
+    allocator: *Allocator,
+    r: []Limb,
+    a: []const Limb,
+    b: []const Limb,
+) error{OutOfMemory}!void {
     @setRuntimeSafety(debug_safety);
+    assert(r.len >= a.len + b.len);
+    assert(a.len >= b.len);
 
-    assert(r.len >= x.len + y.len + 1);
-
-    const split = @divFloor(x.len, 2);
-    var x0 = x[0..split];
-    var x1 = x[split..x.len];
-    var y0 = y[0..split];
-    var y1 = y[split..y.len];
-
-    var tmp = try allocator.alloc(Limb, x1.len + y1.len + 1);
+    // Classical karatsuba algorithm:
+    // a = a1 * B + a0
+    // b = b1 * B + b0
+    // Where a0, b0 < B
+    //
+    // We then have:
+    // ab = a * b
+    //    = (a1 * B + a0) * (b1 * B + b0)
+    //    = a1 * b1 * B * B + a1 * B * b0 + a0 * b1 * B + a0 * b0
+    //    = a1 * b1 * B * B + (a1 * b0 + a0 * b1) * B + a0 * b0
+    //
+    // Note that:
+    // a1 * b0 + a0 * b1
+    //    = (a1 + a0)(b1 + b0) - a1 * b1 - a0 * b0
+    //    = (a0 - a1)(b1 - b0) + a1 * b1 + a0 * b0
+    //
+    // This yields:
+    // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
+    //
+    // Where:
+    // p0 = a0 * b0
+    // p1 = (a0 - a1)(b1 - b0)
+    // p2 = a1 * b1
+    //
+    // Note, (a0 - a1) and (b1 - b0) produce values -B < x < B, and so we need to mind the sign here.
+    // We also have:
+    // 0 <= p0 <= 2B
+    // -2B <= p1 <= 2B
+    //
+    // Note, when B is a multiple of the limb size, multiplies by B amount to shifts or
+    // slices of a limbs array.
+
+    const split = b.len / 2; // B
+    const a0 = a[0..llnormalize(a[0..split])];
+    const a1 = a[split..][0..llnormalize(a[split..])];
+    const b0 = b[0..llnormalize(b[0..split])];
+    const b1 = b[split..][0..llnormalize(b[split..])];
+
+    // Note that the above slices work because we have a.len >= b.len.
+    // We now also have:
+    // a1.len >= a0.len
+    // a1.len >= b1.len >= b0.len
+    // a0.len == b0.len
+
+    // We need some temporary memory to store intermediate results.
+    // Note, we can reduce the amount of temporaries we need by reordering the computation here:
+    // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
+    //    = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0
+    //    = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B
+    // By allocating (a.len - split) + (b.len - split) limbs, an upper bound on a1.len + b1.len, we can be sure that all the intermediary results fit.
+    const tmp = try allocator.alloc(Limb, a.len - split + b.len - split);
     defer allocator.free(tmp);
-    mem.set(Limb, tmp, 0);
 
-    llmulacc(allocator, tmp, x1, y1);
+    // Compute p2.
+    mem.set(Limb, tmp, 0);
+    llmulacc(.add, allocator, tmp, a1, b1);
+    const p2 = tmp[0 .. llnormalize(tmp)];
 
-    var length = llnormalize(tmp);
-    _ = llaccum(r[split..], tmp[0..length]);
-    _ = llaccum(r[split * 2 ..], tmp[0..length]);
+    // Add terms p2 * B^2 and p2 * B to the result.
+    _ = llaccum(op, r[split..], p2);
+    _ = llaccum(op, r[split * 2..], p2);
 
-    mem.set(Limb, tmp[0..length], 0);
+    // Compute p0.
+    mem.set(Limb, p2, 0);
+    llmulacc(.add, allocator, tmp, a0, b0);
+    const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])];
 
-    llmulacc(allocator, tmp, x0, y0);
+    // Add terms p0 * B and p0 to the result.
+    _ = llaccum(op, r, p0);
+    _ = llaccum(op, r[split..], p0);
 
-    length = llnormalize(tmp);
-    _ = llaccum(r[0..], tmp[0..length]);
-    _ = llaccum(r[split..], tmp[0..length]);
+    // Finally, compute and add p1.
+    const j0_sign = llcmp(a0, a1);
+    const j1_sign = llcmp(b1, b0);
 
-    const x_cmp = llcmp(x1, x0);
-    const y_cmp = llcmp(y1, y0);
-    if (x_cmp * y_cmp == 0) {
+    if (j0_sign * j1_sign == 0) {
+        // p1 is zero, we don't need to do any computation at all.
         return;
     }
-    const x0_len = llnormalize(x0);
-    const x1_len = llnormalize(x1);
-    var j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len));
-    defer allocator.free(j0);
-    if (x_cmp == 1) {
-        llsub(j0, x1[0..x1_len], x0[0..x0_len]);
+
+    mem.set(Limb, tmp, 0);
+
+    // p1 is nonzero, so compute the intermediary terms j0 = a0 - a1 and j1 = b1 - b0.
+    // Note that in this case, we again need some storage for intermediary results
+    // j0 and j1. Since we have tmp.len >= 2B, we can store both
+    // intermediaries in the already allocated array.
+    const j0 = tmp[0..a1.len];
+    const j1 = tmp[a1.len..];
+
+    // Ensure that no subtraction overflows.
+    if (j0_sign == 1) {
+        // a0 > a1.
+        _ = llsubcarry(j0, a0, a1);
     } else {
-        llsub(j0, x0[0..x0_len], x1[0..x1_len]);
+        // a0 < a1.
+        _ = llsubcarry(j0, a1, a0);
     }
 
-    const y0_len = llnormalize(y0);
-    const y1_len = llnormalize(y1);
-    var j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len));
-    defer allocator.free(j1);
-    if (y_cmp == 1) {
-        llsub(j1, y1[0..y1_len], y0[0..y0_len]);
+    if (j1_sign == 1) {
+        // b1 > b0.
+        _ = llsubcarry(j1, b1, b0);
     } else {
-        llsub(j1, y0[0..y0_len], y1[0..y1_len]);
+        // b1 < b0.
+        _ = llsubcarry(j1, b0, b1);
     }
-    if (x_cmp == y_cmp) {
-        mem.set(Limb, tmp[0..length], 0);
-        llmulacc(allocator, tmp, j0, j1);
 
-        length = llnormalize(tmp);
-        llsub(r[split..], r[split..], tmp[0..length]);
+    if (j0_sign * j1_sign == 1) {
+        // If j0 and j1 are both positive, we now have:
+        // p1 = j0 * j1
+        // If j0 and j1 are both negative, we now have:
+        // p1 = -j0 * -j1 = j0 * j1
+        // In this case we can add p1 to the result using llmulacc.
+        llmulacc(op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]);
     } else {
-        llmulacc(allocator, r[split..], j0, j1);
+        // In this case either j0 or j1 is negative, and we have:
+        // p1 = -(j0 * j1)
+        // Now we need to subtract instead of accumulate.
+        const inverted_op = if (op == .add) .sub else .add;
+        llmulacc(inverted_op, allocator, r[split..], j0[0..llnormalize(j0)], j1[0..llnormalize(j1)]);
     }
 }
 
-// r = r + a
-fn llaccum(r: []Limb, a: []const Limb) Limb {
+// r = r (op) a
+fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
     @setRuntimeSafety(debug_safety);
+    if (op == .sub) {
+        return llsubcarry(r, r, a);
+    }
+
     assert(r.len != 0 and a.len != 0);
     assert(r.len >= a.len);
 
@@ -2486,24 +2583,53 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
     }
 }
 
-fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void {
+// r = r (op) a * b, accumulated one limb of a at a time. Asserts a.len >= b.len and r.len >= a.len + b.len.
+fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void {
+    @setRuntimeSafety(debug_safety);
+    assert(r.len >= a.len + b.len);
+    assert(a.len >= b.len);
+
+    var i: usize = 0;
+    while (i < a.len) : (i += 1) {
+        llmulLimb(op, r[i..], b, a[i]);
+    }
+}
+
+// r = r (op) y * xi
+fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
     @setRuntimeSafety(debug_safety);
     if (xi == 0) {
         return;
     }
 
-    var carry: Limb = 0;
     var a_lo = acc[0..y.len];
     var a_hi = acc[y.len..];
 
-    var j: usize = 0;
-    while (j < a_lo.len) : (j += 1) {
-        a_lo[j] = @call(.{ .modifier = .always_inline }, addMulLimbWithCarry, .{ a_lo[j], y[j], xi, &carry });
-    }
+    switch (op) {
+        .add => {
+            var carry: Limb = 0;
+            var j: usize = 0;
+            while (j < a_lo.len) : (j += 1) {
+                a_lo[j] = addMulLimbWithCarry(a_lo[j], y[j], xi, &carry);
+            }
 
-    j = 0;
-    while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
-        carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
+            j = 0;
+            while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
+                carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
+            }
+        },
+        .sub => {
+            var borrow: Limb = 0;
+            var j: usize = 0;
+            while (j < a_lo.len) : (j += 1) {
+                a_lo[j] = subMulLimbWithBorrow(a_lo[j], y[j], xi, &borrow);
+            }
+
+            j = 0;
+            while ((borrow != 0) and (j < a_hi.len)) : (j += 1) {
+                borrow = @boolToInt(@subWithOverflow(Limb, a_hi[j], borrow, &a_hi[j]));
+            }
+        },
     }
 }
 
@@ -2964,7 +3090,7 @@ fn llsignedxor(r: []Limb, a: []const Limb, a_positive: bool, b: []const Limb, b_
 }
 
 /// r MUST NOT alias x.
-fn llsquare_basecase(r: []Limb, x: []const Limb) void {
+fn llsquareBasecase(r: []Limb, x: []const Limb) void {
     @setRuntimeSafety(debug_safety);
 
     const x_norm = x;
@@ -2987,7 +3113,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void {
 
     for (x_norm) |v, i| {
         // Accumulate all the x[i]*x[j] (with i != j) products
-        llmulDigit(r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+        llmulLimb(.add, r[2 * i + 1 ..], x_norm[i + 1 ..], v);
     }
 
     // Each product appears twice, multiply by 2
@@ -2995,7 +3121,7 @@ fn llsquare_basecase(r: []Limb, x: []const Limb) void {
 
     for (x_norm) |v, i| {
         // Compute and add the squares
-        llmulDigit(r[2 * i ..], x[i .. i + 1], v);
+        llmulLimb(.add, r[2 * i ..], x[i .. i + 1], v);
     }
 }
 
@@ -3034,12 +3160,12 @@ fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
     while (i < exp_bits) : (i += 1) {
         // Square
         mem.set(Limb, tmp2, 0);
-        llsquare_basecase(tmp2, tmp1[0..llnormalize(tmp1)]);
+        llsquareBasecase(tmp2, tmp1[0..llnormalize(tmp1)]);
         mem.swap([]Limb, &tmp1, &tmp2);
         // Multiply by a
         if (@shlWithOverflow(u32, exp, 1, &exp)) {
             mem.set(Limb, tmp2, 0);
-            llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a);
+            llmulacc(.add, null, tmp2, tmp1[0..llnormalize(tmp1)], a);
             mem.swap([]Limb, &tmp1, &tmp2);
         }
     }