Commit 87b7b31557

Robin Voetter <robin@voetter.nl>
2021-10-23 21:03:19
big ints: improve division
1 parent ee98d87
Changed files (3)
lib
std
src
lib/std/math/big/int.zig
@@ -33,7 +33,7 @@ pub fn calcToStringLimbsBufferLen(a_len: usize, base: u8) usize {
 }
 
 pub fn calcDivLimbsBufferLen(a_len: usize, b_len: usize) usize {
-    return calcMulLimbsBufferLen(a_len, b_len, 2) * 4;
+    return a_len + b_len + 4;
 }
 
 pub fn calcMulLimbsBufferLen(a_len: usize, b_len: usize, aliases: usize) usize {
@@ -760,8 +760,8 @@ pub const Mutable = struct {
     /// q may alias with a or b.
     ///
     /// Asserts there is enough memory to store q and r.
-    /// The upper bound for r limb count is a.limbs.len.
-    /// The upper bound for q limb count is given by `a.limbs.len + b.limbs.len + 1`.
+    /// The upper bound for r limb count is b.limbs.len.
+    /// The upper bound for q limb count is given by `a.limbs.len + b.limbs.len`.
     ///
     /// If `allocator` is provided, it will be used for temporary storage to improve
     /// multiplication performance. `error.OutOfMemory` is handled with a fallback algorithm.
@@ -773,19 +773,17 @@ pub const Mutable = struct {
         a: Const,
         b: Const,
         limbs_buffer: []Limb,
-        allocator: ?*Allocator,
     ) void {
-        div(q, r, a, b, limbs_buffer, allocator);
+        div(q, r, a, b, limbs_buffer);
 
         // Trunc -> Floor.
         if (a.positive and b.positive) return;
 
         if ((!q.positive or q.eqZero()) and !r.eqZero()) {
-            const one: Const = .{ .limbs = &[_]Limb{1}, .positive = true };
-            q.sub(q.toConst(), one);
+            q.addScalar(q.toConst(), -1);
         }
 
-        r.mulNoAlias(q.toConst(), b, allocator);
+        r.mulNoAlias(q.toConst(), b, null);
         r.sub(a, r.toConst());
     }
 
@@ -809,9 +807,8 @@ pub const Mutable = struct {
         a: Const,
         b: Const,
         limbs_buffer: []Limb,
-        allocator: ?*Allocator,
     ) void {
-        div(q, r, a, b, limbs_buffer, allocator);
+        div(q, r, a, b, limbs_buffer);
         r.positive = a.positive;
     }
 
@@ -1177,7 +1174,7 @@ pub const Mutable = struct {
     }
 
     /// Truncates by default.
-    fn div(quo: *Mutable, rem: *Mutable, a: Const, b: Const, limbs_buffer: []Limb, allocator: ?*Allocator) void {
+    fn div(quo: *Mutable, rem: *Mutable, a: Const, b: Const, limbs_buffer: []Limb) void {
         assert(!b.eqZero()); // division by zero
         assert(quo != rem); // illegal aliasing
 
@@ -1220,11 +1217,9 @@ pub const Mutable = struct {
             rem.positive = true;
         } else {
             // x and y are modified during division
-            const sep_len = calcMulLimbsBufferLen(a.limbs.len, b.limbs.len, 2);
-            const x_limbs = limbs_buffer[0 * sep_len ..][0..sep_len];
-            const y_limbs = limbs_buffer[1 * sep_len ..][0..sep_len];
-            const t_limbs = limbs_buffer[2 * sep_len ..][0..sep_len];
-            const mul_limbs_buf = limbs_buffer[3 * sep_len ..][0..sep_len];
+            const sep_len = a.limbs.len + 2;
+            const x_limbs = limbs_buffer[0 .. sep_len];
+            const y_limbs = limbs_buffer[sep_len..];
 
             var x: Mutable = .{
                 .limbs = x_limbs,
@@ -1238,119 +1233,159 @@ pub const Mutable = struct {
             };
 
             // Shrink x, y such that the trailing zero limbs shared between are removed.
-            mem.copy(Limb, x.limbs, a.limbs[ab_zero_limb_count..a.limbs.len]);
-            mem.copy(Limb, y.limbs, b.limbs[ab_zero_limb_count..b.limbs.len]);
+            mem.copy(Limb, x.limbs, a.limbs[ab_zero_limb_count..]);
+            mem.copy(Limb, y.limbs, b.limbs[ab_zero_limb_count..]);
 
-            divN(quo, rem, &x, &y, t_limbs, mul_limbs_buf, allocator);
+            divmod(quo, rem, &x, &y);
             quo.positive = (a.positive == b.positive);
         }
 
         if (ab_zero_limb_count != 0) {
-            rem.shiftLeft(rem.toConst(), ab_zero_limb_count * limb_bits);
+            // Manually shift here since we know its limb aligned.
+            mem.copyBackwards(Limb, rem.limbs[ab_zero_limb_count..], rem.limbs[0..rem.len]);
+            mem.set(Limb, rem.limbs[0..ab_zero_limb_count], 0);
+            rem.len += ab_zero_limb_count;
         }
     }
 
     /// Handbook of Applied Cryptography, 14.20
     ///
     /// x = qy + r where 0 <= r < y
-    fn divN(
+    fn divmod(
         q: *Mutable,
         r: *Mutable,
         x: *Mutable,
         y: *Mutable,
-        tmp_limbs: []Limb,
-        mul_limb_buf: []Limb,
-        allocator: ?*Allocator,
     ) void {
-        assert(y.len >= 2);
-        assert(x.len >= y.len);
-        assert(q.limbs.len >= x.len + y.len - 1);
-
-        // See 3.2
-        var backup_tmp_limbs: [3]Limb = undefined;
-        const t_limbs = if (tmp_limbs.len < 3) &backup_tmp_limbs else tmp_limbs;
-
-        var tmp: Mutable = .{
-            .limbs = t_limbs,
-            .len = 1,
-            .positive = true,
-        };
-        tmp.limbs[0] = 0;
+        // 0.
+        // Normalize so that y[t] > b/2
+        const lz = @clz(Limb, y.limbs[y.len - 1]);
+        const norm_shift = if (lz == 0 and y.toConst().isOdd())
+            limb_bits // Force an extra limb so that y is even.
+        else
+            lz;
 
-        // Normalize so y > limb_bits / 2 (i.e. leading bit is set) and even
-        var norm_shift = @clz(Limb, y.limbs[y.len - 1]);
-        if (norm_shift == 0 and y.toConst().isOdd()) {
-            norm_shift = limb_bits;
-        }
         x.shiftLeft(x.toConst(), norm_shift);
         y.shiftLeft(y.toConst(), norm_shift);
 
         const n = x.len - 1;
         const t = y.len - 1;
+        const shift = n - t;
 
         // 1.
-        q.len = n - t + 1;
+        // for 0 <= j <= n - t, set q[j] to 0
+        q.len = shift + 1;
         q.positive = true;
         mem.set(Limb, q.limbs[0..q.len], 0);
 
         // 2.
-        tmp.shiftLeft(y.toConst(), limb_bits * (n - t));
-        while (x.toConst().order(tmp.toConst()) != .lt) {
-            q.limbs[n - t] += 1;
-            x.sub(x.toConst(), tmp.toConst());
+        // while x >= y * b^(n - t):
+        //    x -= y * b^(n - t)
+        //    q[n - t] += 1
+        // Note, this algorithm is performed only once if y[t] > radix/2 and y is even, which we
+        // enforced in step 0. This means we can replace the while with an if.
+        // Note, multiplication by b^(n - t) comes down to shifting to the right by n - t limbs.
+        // We can also replace x >= y * b^(n - t) by x/b^(n - t) >= y, and use shifts for that.
+        {
+            // x >= y * b^(n - t) can be replaced by x/b^(n - t) >= y.
+
+            // 'divide' x by b^(n - t)
+            var tmp = Mutable{
+                .limbs = x.limbs[shift..],
+                .len = x.len - shift,
+                .positive = true,
+            };
+
+            if (tmp.toConst().order(y.toConst()) != .lt) {
+                // Perform x -= y * b^(n - t)
+                // Note, we can subtract y from x[n - t..] and get the result without shifting.
+                // We can also re-use tmp which already contains the relevant part of x. Note that
+                // this also edits x.
+                // Due to the check above, this cannot underflow.
+                tmp.sub(tmp.toConst(), y.toConst());
+
+                // tmp.sub normalized tmp, but we need to normalize x now.
+                x.limbs.len = tmp.limbs.len + shift;
+
+                q.limbs[shift] += 1;
+            }
         }
 
         // 3.
+        // for i from n down to t + 1, do
         var i = n;
-        while (i > t) : (i -= 1) {
-            // 3.1
+        while (i >= t + 1) : (i -= 1) {
+            const k = i - t - 1;
+            // 3.1.
+            // if x_i == y_t:
+            //   q[i - t - 1] = b - 1
+            // else:
+            //   q[i - t - 1] = (x[i] * b + x[i - 1]) / y[t]
             if (x.limbs[i] == y.limbs[t]) {
-                q.limbs[i - t - 1] = maxInt(Limb);
+                q.limbs[k] = maxInt(Limb);
             } else {
-                const num = (@as(DoubleLimb, x.limbs[i]) << limb_bits) | @as(DoubleLimb, x.limbs[i - 1]);
-                const z = @intCast(Limb, num / @as(DoubleLimb, y.limbs[t]));
-                q.limbs[i - t - 1] = if (z > maxInt(Limb)) maxInt(Limb) else @as(Limb, z);
+                const q0 = (@as(DoubleLimb, x.limbs[i]) << limb_bits) | @as(DoubleLimb, x.limbs[i - 1]);
+                const n0 = @as(DoubleLimb, y.limbs[t]);
+                q.limbs[k] = @intCast(Limb, q0 / n0);
             }
 
             // 3.2
-            tmp.limbs[0] = if (i >= 2) x.limbs[i - 2] else 0;
-            tmp.limbs[1] = if (i >= 1) x.limbs[i - 1] else 0;
-            tmp.limbs[2] = x.limbs[i];
-            tmp.normalize(3);
+            // while q[i - t - 1] * (y[t] * b + y[t - 1] > x[i] * b * b + x[i - 1] + x[i - 2]:
+            //   q[i - t - 1] -= 1
+            // Note, if y[t] > b / 2 this part is repeated no more than twice.
+
+            // Extract from y.
+            const y0 = if (t > 0) y.limbs[t - 1] else 0;
+            const y1 = y.limbs[t];
+
+            // Extract from x.
+            // Note, big endian.
+            const tmp0 = [_]Limb{
+                x.limbs[i],
+                if (i >= 1) x.limbs[i - 1] else 0,
+                if (i >= 2) x.limbs[i - 2] else 0,
+            };
 
             while (true) {
-                // 2x1 limb multiplication unrolled against single-limb q[i-t-1]
-                var carry: Limb = 0;
-                r.limbs[0] = addMulLimbWithCarry(0, if (t >= 1) y.limbs[t - 1] else 0, q.limbs[i - t - 1], &carry);
-                r.limbs[1] = addMulLimbWithCarry(0, y.limbs[t], q.limbs[i - t - 1], &carry);
-                r.limbs[2] = carry;
-                r.normalize(3);
-
-                if (r.toConst().orderAbs(tmp.toConst()) != .gt) {
+                // Ad-hoc 2x1 multiplication with q[i - t - 1].
+                // Note, big endian.
+                var tmp1 = [_]Limb{0, undefined, undefined};
+                tmp1[2] = addMulLimbWithCarry(0, y0, q.limbs[k], &tmp1[0]);
+                tmp1[1] = addMulLimbWithCarry(0, y1, q.limbs[k], &tmp1[0]);
+
+                // Big-endian compare
+                if (mem.order(Limb, &tmp1, &tmp0) != .gt)
                     break;
-                }
 
-                q.limbs[i - t - 1] -= 1;
+                q.limbs[k] -= 1;
             }
 
-            // 3.3
-            tmp.set(q.limbs[i - t - 1]);
-            tmp.mul(tmp.toConst(), y.toConst(), mul_limb_buf, allocator);
-            tmp.shiftLeft(tmp.toConst(), limb_bits * (i - t - 1));
-            x.sub(x.toConst(), tmp.toConst());
-
-            if (!x.positive) {
-                tmp.shiftLeft(y.toConst(), limb_bits * (i - t - 1));
-                x.add(x.toConst(), tmp.toConst());
-                q.limbs[i - t - 1] -= 1;
+            // 3.3.
+            // x -= q[i - t - 1] * y * b^(i - t - 1)
+            // Note, we multiply by a single limb here.
+            // The shift doesn't need to be performed if we add the result of the first multiplication
+            // to x[i - t - 1].
+            // mem.set(Limb, x.limbs, 0);
+            const underflow = llmulLimb(.sub, x.limbs[k .. x.len], y.limbs[0 .. y.len], q.limbs[k]);
+
+            // 3.4.
+            // if x < 0:
+            //   x += y * b^(i - t - 1)
+            //   q[i - t - 1] -= 1
+            // Note, we check for x < 0 using the underflow flag from the previous operation.
+            if (underflow) {
+                // While we didn't properly set the signedness of x, this operation should 'flow' it back to positive.
+                llaccum(.add, x.limbs[k .. x.len], y.limbs[0 .. y.len]);
+                q.limbs[k] -= 1;
             }
+
+            x.normalize(x.len);
         }
 
-        // Denormalize
         q.normalize(q.len);
 
+        // De-normalize r.
         r.shiftRight(x.toConst(), norm_shift);
-        r.normalize(r.len);
     }
 
     /// Truncate an integer to a number of bits, following 2s-complement semantics.
@@ -1808,7 +1843,7 @@ pub const Const = struct {
             while (q.len >= 2) {
                 // Passing an allocator here would not be helpful since this division is destroying
                 // information, not creating it. [TODO citation needed]
-                q.divTrunc(&r, q.toConst(), b, rest_of_the_limbs_buf, null);
+                q.divTrunc(&r, q.toConst(), b, rest_of_the_limbs_buf);
 
                 var r_word = r.limbs[0];
                 var i: usize = 0;
@@ -2435,16 +2470,14 @@ pub const Managed = struct {
     /// a / b are floored (rounded towards 0).
     ///
     /// Returns an error if memory could not be allocated.
-    ///
-    /// q's allocator is used for temporary storage to speed up the multiplication.
     pub fn divFloor(q: *Managed, r: *Managed, a: Const, b: Const) !void {
-        try q.ensureCapacity(a.limbs.len + b.limbs.len + 1);
-        try r.ensureCapacity(a.limbs.len);
+        try q.ensureCapacity(a.limbs.len + b.limbs.len);
+        try r.ensureCapacity(b.limbs.len);
         var mq = q.toMutable();
         var mr = r.toMutable();
         const limbs_buffer = try q.allocator.alloc(Limb, calcDivLimbsBufferLen(a.limbs.len, b.limbs.len));
         defer q.allocator.free(limbs_buffer);
-        mq.divFloor(&mr, a, b, limbs_buffer, q.allocator);
+        mq.divFloor(&mr, a, b, limbs_buffer);
         q.setMetadata(mq.positive, mq.len);
         r.setMetadata(mr.positive, mr.len);
     }
@@ -2454,16 +2487,14 @@ pub const Managed = struct {
     /// a / b are truncated (rounded towards -inf).
     ///
     /// Returns an error if memory could not be allocated.
-    ///
-    /// q's allocator is used for temporary storage to speed up the multiplication.
     pub fn divTrunc(q: *Managed, r: *Managed, a: Const, b: Const) !void {
-        try q.ensureCapacity(a.limbs.len + b.limbs.len + 1);
-        try r.ensureCapacity(a.limbs.len);
+        try q.ensureCapacity(a.limbs.len + b.limbs.len);
+        try r.ensureCapacity(b.limbs.len);
         var mq = q.toMutable();
         var mr = r.toMutable();
         const limbs_buffer = try q.allocator.alloc(Limb, calcDivLimbsBufferLen(a.limbs.len, b.limbs.len));
         defer q.allocator.free(limbs_buffer);
-        mq.divTrunc(&mr, a, b, limbs_buffer, q.allocator);
+        mq.divTrunc(&mr, a, b, limbs_buffer);
         q.setMetadata(mq.positive, mq.len);
         r.setMetadata(mr.positive, mr.len);
     }
@@ -2893,20 +2924,22 @@ fn llmulaccLong(comptime op: AccOp, r: []Limb, a: []const Limb, b: []const Limb)
 
     var i: usize = 0;
     while (i < b.len) : (i += 1) {
-        llmulLimb(op, r[i..], a, b[i]);
+        _ = llmulLimb(op, r[i..], a, b[i]);
     }
 }
 
 /// r = r (op) y * xi
 /// The result is computed modulo `r.len`.
-fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
+/// Returns whether the operation overflowed.
+fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) bool {
     @setRuntimeSafety(debug_safety);
     if (xi == 0) {
-        return;
+        return false;
     }
 
-    var a_lo = acc[0..y.len];
-    var a_hi = acc[y.len..];
+    const split = std.math.min(y.len, acc.len);
+    var a_lo = acc[0..split];
+    var a_hi = acc[split..];
 
     switch (op) {
         .add => {
@@ -2920,6 +2953,8 @@ fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
             while ((carry != 0) and (j < a_hi.len)) : (j += 1) {
                 carry = @boolToInt(@addWithOverflow(Limb, a_hi[j], carry, &a_hi[j]));
             }
+
+            return carry != 0;
         },
         .sub => {
             var borrow: Limb = 0;
@@ -2932,6 +2967,8 @@ fn llmulLimb(comptime op: AccOp, acc: []Limb, y: []const Limb, xi: Limb) void {
             while ((borrow != 0) and (j < a_hi.len)) : (j += 1) {
                 borrow = @boolToInt(@subWithOverflow(Limb, a_hi[j], borrow, &a_hi[j]));
             }
+
+            return borrow != 0;
         },
     }
 }
@@ -3424,7 +3461,8 @@ fn llsquareBasecase(r: []Limb, x: []const Limb) void {
 
     for (x_norm) |v, i| {
         // Accumulate all the x[i]*x[j] (with x!=j) products
-        llmulLimb(.add, r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+        const overflow = llmulLimb(.add, r[2 * i + 1 ..], x_norm[i + 1 ..], v);
+        assert(!overflow);
     }
 
     // Each product appears twice, multiply by 2
@@ -3432,7 +3470,8 @@ fn llsquareBasecase(r: []Limb, x: []const Limb) void {
 
     for (x_norm) |v, i| {
         // Compute and add the squares
-        llmulLimb(.add, r[2 * i ..], x[i .. i + 1], v);
+        const overflow = llmulLimb(.add, r[2 * i ..], x[i .. i + 1], v);
+        assert(!overflow);
     }
 }
 
lib/std/math/big/int_test.zig
@@ -1016,7 +1016,7 @@ test "big.int mulWrap multi-multi unsigned" {
     defer c.deinit();
     try c.mulWrap(a.toConst(), b.toConst(), .unsigned, 65);
 
-    try testing.expect((try c.to(u256)) == (op1 * op2) & ((1 << 65) - 1));
+    try testing.expect((try c.to(u128)) == (op1 * op2) & ((1 << 65) - 1));
 }
 
 test "big.int mulWrap multi-multi signed" {
src/value.zig
@@ -2301,11 +2301,11 @@ pub const Value = extern union {
         const rhs_bigint = rhs.toBigInt(&rhs_space);
         const limbs_q = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len + rhs_bigint.limbs.len + 1,
+            lhs_bigint.limbs.len + rhs_bigint.limbs.len,
         );
         const limbs_r = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len,
+            rhs_bigint.limbs.len,
         );
         const limbs_buffer = try allocator.alloc(
             std.math.big.Limb,
@@ -2313,7 +2313,7 @@ pub const Value = extern union {
         );
         var result_q = BigIntMutable{ .limbs = limbs_q, .positive = undefined, .len = undefined };
         var result_r = BigIntMutable{ .limbs = limbs_r, .positive = undefined, .len = undefined };
-        result_q.divTrunc(&result_r, lhs_bigint, rhs_bigint, limbs_buffer, null);
+        result_q.divTrunc(&result_r, lhs_bigint, rhs_bigint, limbs_buffer);
         const result_limbs = result_q.limbs[0..result_q.len];
 
         if (result_q.positive) {
@@ -2332,11 +2332,11 @@ pub const Value = extern union {
         const rhs_bigint = rhs.toBigInt(&rhs_space);
         const limbs_q = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len + rhs_bigint.limbs.len + 1,
+            lhs_bigint.limbs.len + rhs_bigint.limbs.len,
         );
         const limbs_r = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len,
+            rhs_bigint.limbs.len,
         );
         const limbs_buffer = try allocator.alloc(
             std.math.big.Limb,
@@ -2344,7 +2344,7 @@ pub const Value = extern union {
         );
         var result_q = BigIntMutable{ .limbs = limbs_q, .positive = undefined, .len = undefined };
         var result_r = BigIntMutable{ .limbs = limbs_r, .positive = undefined, .len = undefined };
-        result_q.divFloor(&result_r, lhs_bigint, rhs_bigint, limbs_buffer, null);
+        result_q.divFloor(&result_r, lhs_bigint, rhs_bigint, limbs_buffer);
         const result_limbs = result_q.limbs[0..result_q.len];
 
         if (result_q.positive) {
@@ -2363,13 +2363,13 @@ pub const Value = extern union {
         const rhs_bigint = rhs.toBigInt(&rhs_space);
         const limbs_q = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len + rhs_bigint.limbs.len + 1,
+            lhs_bigint.limbs.len + rhs_bigint.limbs.len,
         );
         const limbs_r = try allocator.alloc(
             std.math.big.Limb,
-            // TODO: audit this size, and also consider reworking Sema to re-use Values rather than
+            // TODO: consider reworking Sema to re-use Values rather than
             // always producing new Value objects.
-            rhs_bigint.limbs.len + 1,
+            rhs_bigint.limbs.len,
         );
         const limbs_buffer = try allocator.alloc(
             std.math.big.Limb,
@@ -2377,7 +2377,7 @@ pub const Value = extern union {
         );
         var result_q = BigIntMutable{ .limbs = limbs_q, .positive = undefined, .len = undefined };
         var result_r = BigIntMutable{ .limbs = limbs_r, .positive = undefined, .len = undefined };
-        result_q.divTrunc(&result_r, lhs_bigint, rhs_bigint, limbs_buffer, null);
+        result_q.divTrunc(&result_r, lhs_bigint, rhs_bigint, limbs_buffer);
         const result_limbs = result_r.limbs[0..result_r.len];
 
         if (result_r.positive) {
@@ -2396,11 +2396,11 @@ pub const Value = extern union {
         const rhs_bigint = rhs.toBigInt(&rhs_space);
         const limbs_q = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len + rhs_bigint.limbs.len + 1,
+            lhs_bigint.limbs.len + rhs_bigint.limbs.len,
         );
         const limbs_r = try allocator.alloc(
             std.math.big.Limb,
-            lhs_bigint.limbs.len,
+            rhs_bigint.limbs.len,
         );
         const limbs_buffer = try allocator.alloc(
             std.math.big.Limb,
@@ -2408,7 +2408,7 @@ pub const Value = extern union {
         );
         var result_q = BigIntMutable{ .limbs = limbs_q, .positive = undefined, .len = undefined };
         var result_r = BigIntMutable{ .limbs = limbs_r, .positive = undefined, .len = undefined };
-        result_q.divFloor(&result_r, lhs_bigint, rhs_bigint, limbs_buffer, null);
+        result_q.divFloor(&result_r, lhs_bigint, rhs_bigint, limbs_buffer);
         const result_limbs = result_r.limbs[0..result_r.len];
 
         if (result_r.positive) {