Commit 52207f22de

Brendan Hansknecht <brendan.hansknecht@gmail.com>
2019-11-03 05:37:53
Add karatsuba to big ints
1 parent 711520d
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -766,13 +766,11 @@ pub const Int = struct {
             r.deinit();
         };
 
-        try r.ensureCapacity(a.len() + b.len());
+        try r.ensureCapacity(a.len() + b.len() + 1);
 
-        if (a.len() >= b.len()) {
-            llmul(r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]);
-        } else {
-            llmul(r.limbs, b.limbs[0..b.len()], a.limbs[0..a.len()]);
-        }
+        mem.set(Limb, r.limbs[0 .. a.len() + b.len() + 1], 0);
+
+        try llmulacc(rma.allocator.?, r.limbs, a.limbs[0..a.len()], b.limbs[0..b.len()]);
 
         r.normalize(a.len() + b.len());
         r.setSign(a.isPositive() == b.isPositive());
@@ -780,6 +778,7 @@ pub const Int = struct {
 
     // a + b * c + *carry, sets carry to the overflow bits
     pub fn addMulLimbWithCarry(a: Limb, b: Limb, c: Limb, carry: *Limb) Limb {
+        @setRuntimeSafety(false);
         var r1: Limb = undefined;
 
         // r1 = a + *carry
@@ -800,25 +799,178 @@ pub const Int = struct {
         return r1;
     }
 
+    fn llmulDigit(acc: []Limb, y: []const Limb, xi: Limb) void {
+        @setRuntimeSafety(false);
+        if (xi == 0) {
+            return;
+        }
+
+        var carry: usize = 0;
+        var a_lo = acc[0..y.len];
+        var a_hi = acc[y.len..];
+
+        var j: usize = 0;
+        while (j < a_lo.len) : (j += 1) {
+            a_lo[j] = @inlineCall(addMulLimbWithCarry, a_lo[j], y[j], xi, &carry);
+        }
+
+        j = 0;
+        while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
+            carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
+        }
+    }
+
     // Knuth 4.3.1, Algorithm M.
     //
     // r MUST NOT alias any of a or b.
-    fn llmul(r: []Limb, a: []const Limb, b: []const Limb) void {
+    fn llmulacc(allocator: *Allocator, r: []Limb, a: []const Limb, b: []const Limb) error{OutOfMemory}!void {
         @setRuntimeSafety(false);
-        debug.assert(a.len >= b.len);
-        debug.assert(r.len >= a.len + b.len);
 
-        mem.set(Limb, r[0 .. a.len + b.len], 0);
+        const a_norm = a[0..llnormalize(a)];
+        const b_norm = b[0..llnormalize(b)];
+        var x = a_norm;
+        var y = b_norm;
+        if (a_norm.len > b_norm.len) {
+            x = b_norm;
+            y = a_norm;
+        }
+
+        debug.assert(r.len >= x.len + y.len + 1);
+
+        // 48 is a pretty arbitrary size chosen based on performance of a factorial program.
+        if (x.len <= 48) {
+            // Basecase multiplication
+            var i: usize = 0;
+            while (i < x.len) : (i += 1) {
+                llmulDigit(r[i..], y, x[i]);
+            }
+        } else {
+            // Karatsuba multiplication
+            const split = @divFloor(x.len, 2);
+            var x0 = x[0..split];
+            var x1 = x[split..x.len];
+            var y0 = y[0..split];
+            var y1 = y[split..y.len];
+
+            var tmp = try allocator.alloc(Limb, x1.len + y1.len + 1);
+            defer allocator.free(tmp);
+            mem.set(Limb, tmp, 0);
+
+            try llmulacc(allocator, tmp, x1, y1);
+
+            var length = llnormalize(tmp);
+            _ = llaccum(r[split..], tmp[0..length]);
+            _ = llaccum(r[split * 2 ..], tmp[0..length]);
+
+            mem.set(Limb, tmp[0..length], 0);
+
+            try llmulacc(allocator, tmp, x0, y0);
+
+            length = llnormalize(tmp);
+            _ = llaccum(r[0..], tmp[0..length]);
+            _ = llaccum(r[split..], tmp[0..length]);
+
+            const x_cmp = llcmp(x1, x0);
+            const y_cmp = llcmp(y1, y0);
+            if (x_cmp * y_cmp == 0) {
+                return;
+            }
+            const x0_len = llnormalize(x0);
+            const x1_len = llnormalize(x1);
+            var j0 = try allocator.alloc(Limb, math.max(x0_len, x1_len));
+            defer allocator.free(j0);
+            if (x_cmp == 1) {
+                llsub(j0, x1[0..x1_len], x0[0..x0_len]);
+            } else {
+                llsub(j0, x0[0..x0_len], x1[0..x1_len]);
+            }
+
+            const y0_len = llnormalize(y0);
+            const y1_len = llnormalize(y1);
+            var j1 = try allocator.alloc(Limb, math.max(y0_len, y1_len));
+            defer allocator.free(j1);
+            if (y_cmp == 1) {
+                llsub(j1, y1[0..y1_len], y0[0..y0_len]);
+            } else {
+                llsub(j1, y0[0..y0_len], y1[0..y1_len]);
+            }
+            const j0_len = llnormalize(j0);
+            const j1_len = llnormalize(j1);
+            if (x_cmp == y_cmp) {
+                mem.set(Limb, tmp[0..length], 0);
+                try llmulacc(allocator, tmp, j0, j1);
+
+                length = Int.llnormalize(tmp);
+                llsub(r[split..], r[split..], tmp[0..length]);
+            } else {
+                try llmulacc(allocator, r[split..], j0, j1);
+            }
+        }
+    }
+
+    // r = r + a
+    fn llaccum(r: []Limb, a: []const Limb) Limb {
+        @setRuntimeSafety(false);
+        debug.assert(r.len != 0 and a.len != 0);
+        debug.assert(r.len >= a.len);
 
         var i: usize = 0;
+        var carry: Limb = 0;
+
         while (i < a.len) : (i += 1) {
-            var carry: Limb = 0;
-            var j: usize = 0;
-            while (j < b.len) : (j += 1) {
-                r[i + j] = @inlineCall(addMulLimbWithCarry, r[i + j], a[i], b[j], &carry);
+            var c: Limb = 0;
+            c += @boolToInt(@addWithOverflow(Limb, r[i], a[i], &r[i]));
+            c += @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i]));
+            carry = c;
+        }
+
+        while ((carry != 0) and i < r.len) : (i += 1) {
+            carry = @boolToInt(@addWithOverflow(Limb, r[i], carry, &r[i]));
+        }
+
+        return carry;
+    }
+
+    /// Returns -1, 0, 1 if |a| < |b|, |a| == |b| or |a| > |b| respectively for limbs.
+    pub fn llcmp(a: []const Limb, b: []const Limb) i8 {
+        @setRuntimeSafety(false);
+        const a_len = llnormalize(a);
+        const b_len = llnormalize(b);
+        if (a_len < b_len) {
+            return -1;
+        }
+        if (a_len > b_len) {
+            return 1;
+        }
+
+        var i: usize = a_len - 1;
+        while (i != 0) : (i -= 1) {
+            if (a[i] != b[i]) {
+                break;
             }
-            r[i + j] = carry;
         }
+
+        if (a[i] < b[i]) {
+            return -1;
+        } else if (a[i] > b[i]) {
+            return 1;
+        } else {
+            return 0;
+        }
+    }
+
+    // returns the length of `a` ignoring zero limbs at the high-index end; never less than 1.
+    fn llnormalize(a: []const Limb) usize {
+        @setRuntimeSafety(false);
+        var j = a.len;
+        while (j > 0) : (j -= 1) {
+            if (a[j - 1] != 0) {
+                break;
+            }
+        }
+
+        // Handle zero
+        return if (j != 0) j else 1;
     }
 
     /// q = a / b (rem r)