Commit 41e9c1bac1

Robin Voetter <robin@voetter.nl>
2021-10-03 16:03:43
big ints: Allow llmulaccum to wrap
1 parent 5907b3e
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -658,7 +658,7 @@ pub const Mutable = struct {
 
         mem.set(Limb, rma.limbs[0 .. a.limbs.len + b.limbs.len + 1], 0);
 
-        llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
+        _ = llmulacc(.add, allocator, rma.limbs, a.limbs, b.limbs);
 
         rma.normalize(a.limbs.len + b.limbs.len);
         rma.positive = (a.positive == b.positive);
@@ -2365,9 +2365,12 @@ const AccOp = enum {
 ///
 /// r = r (op) a * b
 /// r MUST NOT alias any of a or b.
+///
+/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
 fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []const Limb, b: []const Limb) void {
     @setRuntimeSafety(debug_safety);
-    assert(r.len >= a.len + b.len);
+    assert(r.len >= a.len);
+    assert(r.len >= b.len);
 
     // Order greatest first.
     var x = a;
@@ -2395,6 +2398,8 @@ fn llmulacc(comptime op: AccOp, opt_allocator: ?*Allocator, r: []Limb, a: []cons
 ///
 /// r = r (op) a * b
 /// r MUST NOT alias any of a or b.
+///
+/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
 fn llmulaccKaratsuba(
     comptime op: AccOp,
     allocator: *Allocator,
@@ -2403,7 +2408,7 @@ fn llmulaccKaratsuba(
     b: []const Limb,
 ) error{OutOfMemory}!void {
     @setRuntimeSafety(debug_safety);
-    assert(r.len >= a.len + b.len);
+    assert(r.len >= a.len);
     assert(a.len >= b.len);
 
     // Classical karatsuba algorithm:
@@ -2437,49 +2442,84 @@ fn llmulaccKaratsuba(
     //
     // Note, when B is a multiple of the limb size, multiplies by B amount to shifts or
     // slices of a limbs array.
+    //
+    // This function computes the result of the multiplication modulo r.len. This means:
+    // - p2 and p1 only need to be computed modulo r.len - B.
+    // - In the case of p2, p2 * B^2 needs to be added modulo r.len - 2 * B.
 
     const split = b.len / 2; // B
+
+    const limbs_after_split = r.len - split; // Limbs to compute for p1 and p2.
+    const limbs_after_split2 = r.len - split * 2; // Limbs to add for p2 * B^2.
+
+    // For a0 and b0 we need the full range.
     const a0 = a[0..llnormalize(a[0..split])];
-    const a1 = a[split..][0..llnormalize(a[split..])];
     const b0 = b[0..llnormalize(b[0..split])];
-    const b1 = b[split..][0..llnormalize(b[split..])];
 
-    // Note that the above slices work because we have a.len > b.len.
-    // We now also have:
-    // a1.len >= a0.len
-    // a1.len >= b1.len >= b0.len
-    // a0.len == b0.len
+    // For a1 and b1 we only need `limbs_after_split` limbs.
+    const a1 = blk: {
+        var a1 = a[split..];
+        a1.len = math.min(llnormalize(a1), limbs_after_split);
+        break :blk a1;
+    };
+
+    const b1 = blk: {
+        var b1 = b[split..];
+        b1.len = math.min(llnormalize(b1), limbs_after_split);
+        break :blk b1;
+    };
+
+    // Note that the above slices relative to `split` work because we have a.len > b.len.
 
     // We need some temporary memory to store intermediate results.
     // Note, we can reduce the amount of temporaries we need by reordering the computation here:
     // ab = p2 * B^2 + (p0 + p1 + p2) * B + p0
     //    = p2 * B^2 + (p0 * B + p1 * B + p2 * B) + p0
     //    = (p2 * B^2 + p2 * B) + (p0 * B + p0) + p1 * B
-    // By allocating a1.len * b1.len we can be sure that all the intermediary results fit.
+
+    // Allocate at least enough memory to be able to multiply the upper two segments of a and b, assuming
+    // no overflow.
     const tmp = try allocator.alloc(Limb, a.len - split + b.len - split);
     defer allocator.free(tmp);
 
     // Compute p2.
-    mem.set(Limb, tmp, 0);
-    llmulacc(.add, allocator, tmp, a1, b1);
-    const p2 = tmp[0 .. llnormalize(tmp)];
+    // Note, we don't need to compute all of p2, just enough limbs to satisfy r.
+    const p2_limbs = math.min(limbs_after_split, a1.len + b1.len);
 
-    // Add terms p2 * B^2 and p2 * B to the result.
-    _ = llaccum(op, r[split..], p2);
-    _ = llaccum(op, r[split * 2..], p2);
+    mem.set(Limb, tmp[0..p2_limbs], 0);
+    llmulacc(.add, allocator, tmp[0..p2_limbs], a1[0..math.min(a1.len, p2_limbs)], b1[0..math.min(b1.len, p2_limbs)]);
+    const p2 = tmp[0 .. llnormalize(tmp[0..p2_limbs])];
+
+    // Add p2 * B to the result.
+    llaccum(op, r[split..], p2);
+
+    // Add p2 * B^2 to the result if required.
+    if (limbs_after_split2 > 0) {
+        llaccum(op, r[split * 2..], p2[0..math.min(p2.len, limbs_after_split2)]);
+    }
 
     // Compute p0.
-    mem.set(Limb, p2, 0);
-    llmulacc(.add, allocator, tmp, a0, b0);
-    const p0 = tmp[0 .. llnormalize(tmp[0..a0.len + b0.len])];
+    // Since a0.len, b0.len <= split and r.len >= split * 2, the full width of p0 needs to be computed.
+    const p0_limbs = a0.len + b0.len;
+    mem.set(Limb, tmp[0..p0_limbs], 0);
+    llmulacc(.add, allocator, tmp[0..p0_limbs], a0, b0);
+    const p0 = tmp[0 .. llnormalize(tmp[0..p0_limbs])];
+
+    // Add p0 to the result.
+    llaccum(op, r, p0);
+
+    // Add p0 * B to the result. In this case, we may not need all of it.
+    llaccum(op, r[split..], p0[0..math.min(limbs_after_split, p0.len)]);
 
-    // Add terms p0 * B and p0 to the result.
-    _ = llaccum(op, r, p0);
-    _ = llaccum(op, r[split..], p0);
 
     // Finally, compute and add p1.
-    const j0_sign = llcmp(a0, a1);
-    const j1_sign = llcmp(b1, b0);
+    // From now on we only need `limbs_after_split` limbs for a0 and b0, since the result of the
+    // following computation will be added * B.
+    const a0x = a0[0..std.math.min(a0.len, limbs_after_split)];
+    const b0x = b0[0..std.math.min(b0.len, limbs_after_split)];
+
+    const j0_sign = llcmp(a0x, a1);
+    const j1_sign = llcmp(b1, b0x);
 
     if (j0_sign * j1_sign == 0) {
         // p1 is zero, we don't need to do any computation at all.
@@ -2492,24 +2532,24 @@ fn llmulaccKaratsuba(
     // Note that in this case, we again need some storage for intermediary results
     // j0 and j1. Since we have tmp.len >= 2B, we can store both
     // intermediaries in the already allocated array.
-    const j0 = tmp[0..a1.len];
-    const j1 = tmp[a1.len..];
+    const j0 = tmp[0..a.len - split];
+    const j1 = tmp[a.len - split..];
 
     // Ensure that no subtraction overflows.
     if (j0_sign == 1) {
         // a0 > a1.
-        _ = llsubcarry(j0, a0, a1);
+        _ = llsubcarry(j0, a0x, a1);
     } else {
         // a0 < a1.
-        _ = llsubcarry(j0, a1, a0);
+        _ = llsubcarry(j0, a1, a0x);
     }
 
     if (j1_sign == 1) {
         // b1 > b0.
-        _ = llsubcarry(j1, b1, b0);
+        _ = llsubcarry(j1, b1, b0x);
     } else {
         // b1 > b0.
-        _ = llsubcarry(j1, b0, b1);
+        _ = llsubcarry(j1, b0x, b1);
     }
 
     if (j0_sign * j1_sign == 1) {
@@ -2528,11 +2568,13 @@ fn llmulaccKaratsuba(
     }
 }
 
-// r = r (op) a
-fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
+/// r = r (op) a.
+/// The result is computed modulo `r.len`.
+fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) void {
     @setRuntimeSafety(debug_safety);
     if (op == .sub) {
-        return llsubcarry(r, r, a);
+        _ = llsubcarry(r, r, a);
+        return;
     }
 
     assert(r.len != 0 and a.len != 0);
@@ -2551,8 +2593,6 @@ fn llaccum(comptime op: AccOp, r: []Limb, a: []const Limb) Limb {
     while ((carry != 0) and i < r.len) : (i += 1) {
         carry = @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i]));
     }
-
-    return carry;
 }
 
 /// Returns -1, 0, 1 if |a| < |b|, |a| == |b| or |a| > |b| respectively for limbs.
@@ -2583,19 +2623,21 @@ pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
     }
 }
 
-// r = r (op) y * xi
+/// r = r (op) y * xi
+/// The result is computed modulo `r.len`. When `r.len >= a.len + b.len`, no overflow occurs.
 fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb) void {
     @setRuntimeSafety(debug_safety);
     assert(r.len >= a.len + b.len);
     assert(a.len >= b.len);
 
     var i: usize = 0;
-    while (i < a.len) : (i += 1) {
-        llmulLimb(op, r[i..], b, a[i]);
+    while (i < b.len) : (i += 1) {
+        llmulLimb(op, r[i..], a, b[i]);
     }
 }
 
-// r = r (op) y * xi
+/// r = r (op) y * xi
+/// The result is computed modulo `r.len`.
 fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
     @setRuntimeSafety(debug_safety);
     if (xi == 0) {