Commit c905ceb23c

Robin Voetter <robin@voetter.nl>
2021-10-24 02:39:56
big ints: fix divFloor
1 parent 87b7b31
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -774,17 +774,114 @@ pub const Mutable = struct {
         b: Const,
         limbs_buffer: []Limb,
     ) void {
-        div(q, r, a, b, limbs_buffer);
+        const sep = a.limbs.len + 2;
+        var x = a.toMutable(limbs_buffer[0..sep]);
+        var y = b.toMutable(limbs_buffer[sep..]);
+
+        div(q, r, &x, &y);
+
+        // Note, `div` performs truncating division, which satisfies
+        // @divTrunc(a, b) * b + @rem(a, b) = a
+        // so r = a - @divTrunc(a, b) * b
+        // Note,  @rem(a, -b) = @rem(-b, a) = -@rem(a, b) = -@rem(-a, -b)
+        // For divTrunc, we want to perform
+        // @divFloor(a, b) * b + @mod(a, b) = a
+        // Note:
+        // @divFloor(-a, b)
+        // = @divFloor(a, -b)
+        // = -@divCeil(a, b)
+        // = -@divFloor(a + b - 1, b)
+        // = -@divTrunc(a + b - 1, b)
+
+        // Note (1):
+        // @divTrunc(a + b - 1, b) * b + @rem(a + b - 1, b) = a + b - 1
+        // = @divTrunc(a + b - 1, b) * b + @rem(a - 1, b) = a + b - 1
+        // = @divTrunc(a + b - 1, b) * b + @rem(a - 1, b) - b + 1 = a
+
+        if (a.positive and b.positive) {
+            // Positive-positive case, don't need to do anything.
+        } else if (a.positive and !b.positive) {
+            // a/-b -> q is negative, and so we need to fix flooring.
+            // Subtract one to make the division flooring.
+
+            // @divFloor(a, -b) * -b + @mod(a, -b) = a
+            // If b divides a exactly, we have @divFloor(a, -b) * -b = a
+            // Else, we have @divFloor(a, -b) * -b > a, so @mod(a, -b) becomes negative
+
+            // We have:
+            // @divFloor(a, -b) * -b + @mod(a, -b) = a
+            // = -@divTrunc(a + b - 1, b) * -b + @mod(a, -b) = a
+            // = @divTrunc(a + b - 1, b) * b + @mod(a, -b) = a
+
+            // Substitute a for (1):
+            // @divTrunc(a + b - 1, b) * b + @rem(a - 1, b) - b + 1 = @divTrunc(a + b - 1, b) * b + @mod(a, -b)
+            // Yields:
+            // @mod(a, -b) = @rem(a - 1, b) - b + 1
+            // Note that `r` holds @rem(a, b) at this point.
+            //
+            // If @rem(a, b) is not 0:
+            //   @rem(a - 1, b) = @rem(a, b) - 1
+            //   => @mod(a, -b) = @rem(a, b) - 1 - b + 1 = @rem(a, b) - b
+            // Else:
+            //   @rem(a - 1, b) = @rem(a + b - 1, b) = @rem(b - 1, b) = b - 1
+            //   => @mod(a, -b) = b - 1 - b + 1 = 0
+            if (!r.eqZero()) {
+                q.addScalar(q.toConst(), -1);
+                r.positive = true;
+                r.sub(r.toConst(), y.toConst().abs());
+            }
+        } else if (!a.positive and b.positive) {
+            // -a/b -> q is negative, and so we need to fix flooring.
+            // Subtract one to make the division flooring.
+
+            // @divFloor(-a, b) * b + @mod(-a, b) = a
+            // If b divides a exactly, we have @divFloor(-a, b) * b = -a
+            // Else, we have @divFloor(-a, b) * b < -a, so @mod(-a, b) becomes positive
+
+            // We have:
+            // @divFloor(-a, b) * b + @mod(-a, b) = -a
+            // = -@divTrunc(a + b - 1, b) * b + @mod(-a, b) = -a
+            // = @divTrunc(a + b - 1, b) * b - @mod(-a, b) = a
+
+            // Substitute a for (1):
+            // @divTrunc(a + b - 1, b) * b + @rem(a - 1, b) - b + 1 = @divTrunc(a + b - 1, b) * b - @mod(-a, b)
+            // Yields:
+            // @rem(a - 1, b) - b + 1 = -@mod(-a, b)
+            // => -@mod(-a, b) = @rem(a - 1, b) - b + 1
+            // => @mod(-a, b) = -(@rem(a - 1, b) - b + 1) = -@rem(a - 1, b) + b - 1
+            //
+            // If @rem(a, b) is not 0:
+            //   @rem(a - 1, b) = @rem(a, b) - 1
+            //   => @mod(-a, b) = -(@rem(a, b) - 1) + b - 1 = -@rem(a, b) + 1 + b - 1 = -@rem(a, b) + b
+            // Else :
+            //   @rem(a - 1, b) = b - 1
+            //   => @mod(-a, b) = -(b - 1) + b - 1 = 0
+            if (!r.eqZero()) {
+                q.addScalar(q.toConst(), -1);
+                r.positive = false;
+                r.add(r.toConst(), y.toConst().abs());
+            }
+        } else if (!a.positive and !b.positive) {
+            // a/b -> q is positive, don't need to do anything to fix flooring.
 
-        // Trunc -> Floor.
-        if (a.positive and b.positive) return;
+            // @divFloor(-a, -b) * -b + @mod(-a, -b) = -a
+            // If b divides a exactly, we have @divFloor(-a, -b) * -b = -a
+            // Else, we have @divFloor(-a, -b) * -b > -a, so @mod(-a, -b) becomes negative
 
-        if ((!q.positive or q.eqZero()) and !r.eqZero()) {
-            q.addScalar(q.toConst(), -1);
-        }
+            // We have:
+            // @divFloor(-a, -b) * -b + @mod(-a, -b) = -a
+            // = @divTrunc(a, b) * -b + @mod(-a, -b) = -a
+            // = @divTrunc(a, b) * b - @mod(-a, -b) = a
+
+            // We also have:
+            // @divTrunc(a, b) * b + @rem(a, b) = a
 
-        r.mulNoAlias(q.toConst(), b, null);
-        r.sub(a, r.toConst());
+            // Substitute a:
+            // @divTrunc(a, b) * b + @rem(a, b) = @divTrunc(a, b) * b - @mod(-a, -b)
+            // => @rem(a, b) = -@mod(-a, -b)
+            // => @mod(-a, -b) = -@rem(a, b)
+            r.positive = false;
+        }
     }
 
     /// q = a / b (rem r)
@@ -808,8 +905,11 @@ pub const Mutable = struct {
         b: Const,
         limbs_buffer: []Limb,
     ) void {
-        div(q, r, a, b, limbs_buffer);
-        r.positive = a.positive;
+        const sep = a.limbs.len + 2;
+        var x = a.toMutable(limbs_buffer[0..sep]);
+        var y = b.toMutable(limbs_buffer[sep..]);
+
+        div(q, r, &x, &y);
     }
 
     /// r = a << shift, in other words, r = a * 2^shift
@@ -1173,84 +1273,78 @@ pub const Mutable = struct {
         result.copy(x.toConst());
     }
 
-    /// Truncates by default.
-    fn div(quo: *Mutable, rem: *Mutable, a: Const, b: Const, limbs_buffer: []Limb) void {
-        assert(!b.eqZero()); // division by zero
-        assert(quo != rem); // illegal aliasing
+    // Truncates by default.
+    fn div(q: *Mutable, r: *Mutable, x: *Mutable, y: *Mutable) void {
+        assert(!y.eqZero()); // division by zero
+        assert(q != r); // illegal aliasing
 
-        if (a.orderAbs(b) == .lt) {
-            // quo may alias a so handle rem first
-            rem.copy(a);
-            rem.positive = a.positive == b.positive;
+        const q_positive = (x.positive == y.positive);
+        const r_positive = x.positive;
 
-            quo.positive = true;
-            quo.len = 1;
-            quo.limbs[0] = 0;
+        if (x.toConst().orderAbs(y.toConst()) == .lt) {
+            // q may alias x so handle r first.
+            r.copy(x.toConst());
+            r.positive = r_positive;
+
+            q.set(0);
             return;
         }
 
         // Handle trailing zero-words of divisor/dividend. These are not handled in the following
         // algorithms.
-        const a_zero_limb_count = blk: {
-            var i: usize = 0;
-            while (i < a.limbs.len) : (i += 1) {
-                if (a.limbs[i] != 0) break;
-            }
-            break :blk i;
-        };
-        const b_zero_limb_count = blk: {
-            var i: usize = 0;
-            while (i < b.limbs.len) : (i += 1) {
-                if (b.limbs[i] != 0) break;
-            }
-            break :blk i;
-        };
+        // Note, there must be a non-zero limb for either.
+        // const x_trailing = std.mem.indexOfScalar(Limb, x.limbs[0..x.len], 0).?;
+        // const y_trailing = std.mem.indexOfScalar(Limb, y.limbs[0..y.len], 0).?;
 
-        const ab_zero_limb_count = math.min(a_zero_limb_count, b_zero_limb_count);
+        const x_trailing = for (x.limbs[0..x.len]) |xi, i| {
+            if (xi != 0) break i;
+        } else unreachable;
 
-        if (b.limbs.len - ab_zero_limb_count == 1) {
-            lldiv1(quo.limbs[0..], &rem.limbs[0], a.limbs[ab_zero_limb_count..a.limbs.len], b.limbs[b.limbs.len - 1]);
-            quo.normalize(a.limbs.len - ab_zero_limb_count);
-            quo.positive = (a.positive == b.positive);
+        const y_trailing = for (y.limbs[0..y.len]) |yi, i| {
+            if (yi != 0) break i;
+        } else unreachable;
 
-            rem.len = 1;
-            rem.positive = true;
-        } else {
-            // x and y are modified during division
-            const sep_len = a.limbs.len + 2;
-            const x_limbs = limbs_buffer[0 .. sep_len];
-            const y_limbs = limbs_buffer[sep_len..];
+        const xy_trailing = math.min(x_trailing, y_trailing);
+
+        if (y.len - xy_trailing == 1) {
+            lldiv1(q.limbs, &r.limbs[0], x.limbs[xy_trailing..x.len], y.limbs[y.len - 1]);
+            q.normalize(x.len - xy_trailing);
+            q.positive = q_positive;
 
-            var x: Mutable = .{
-                .limbs = x_limbs,
+            r.len = 1;
+            r.positive = r_positive;
+        } else {
+            // Shrink x, y such that the trailing zero limbs shared between are removed.
+            var x0 = Mutable{
+                .limbs = x.limbs[xy_trailing..],
+                .len = x.len - xy_trailing,
                 .positive = true,
-                .len = a.limbs.len - ab_zero_limb_count,
             };
-            var y: Mutable = .{
-                .limbs = y_limbs,
+
+            var y0 = Mutable{
+                .limbs = y.limbs[xy_trailing..],
+                .len = y.len - xy_trailing,
                 .positive = true,
-                .len = b.limbs.len - ab_zero_limb_count,
             };
 
-            // Shrink x, y such that the trailing zero limbs shared between are removed.
-            mem.copy(Limb, x.limbs, a.limbs[ab_zero_limb_count..]);
-            mem.copy(Limb, y.limbs, b.limbs[ab_zero_limb_count..]);
+            divmod(q, r, &x0, &y0);
+            q.positive = q_positive;
 
-            divmod(quo, rem, &x, &y);
-            quo.positive = (a.positive == b.positive);
+            r.positive = r_positive;
         }
 
-        if (ab_zero_limb_count != 0) {
+        if (xy_trailing != 0) {
             // Manually shift here since we know its limb aligned.
-            mem.copyBackwards(Limb, rem.limbs[ab_zero_limb_count..], rem.limbs[0..rem.len]);
-            mem.set(Limb, rem.limbs[0..ab_zero_limb_count], 0);
-            rem.len += ab_zero_limb_count;
+            mem.copyBackwards(Limb, r.limbs[xy_trailing..], r.limbs[0..r.len]);
+            mem.set(Limb, r.limbs[0..xy_trailing], 0);
+            r.len += xy_trailing;
         }
     }
 
     /// Handbook of Applied Cryptography, 14.20
     ///
     /// x = qy + r where 0 <= r < y
+    /// y is modified but returned intact.
     fn divmod(
         q: *Mutable,
         r: *Mutable,
@@ -1349,7 +1443,7 @@ pub const Mutable = struct {
             while (true) {
                 // Ad-hoc 2x1 multiplication with q[i - t - 1].
                 // Note, big endian.
-                var tmp1 = [_]Limb{0, undefined, undefined};
+                var tmp1 = [_]Limb{ 0, undefined, undefined };
                 tmp1[2] = addMulLimbWithCarry(0, y0, q.limbs[k], &tmp1[0]);
                 tmp1[1] = addMulLimbWithCarry(0, y1, q.limbs[k], &tmp1[0]);
 
@@ -1366,7 +1460,7 @@ pub const Mutable = struct {
             // The shift doesn't need to be performed if we add the result of the first multiplication
             // to x[i - t - 1].
             // mem.set(Limb, x.limbs, 0);
-            const underflow = llmulLimb(.sub, x.limbs[k .. x.len], y.limbs[0 .. y.len], q.limbs[k]);
+            const underflow = llmulLimb(.sub, x.limbs[k..x.len], y.limbs[0..y.len], q.limbs[k]);
 
             // 3.4.
             // if x < 0:
@@ -1375,7 +1469,7 @@ pub const Mutable = struct {
             // Note, we check for x < 0 using the underflow flag from the previous operation.
             if (underflow) {
                 // While we didn't properly set the signedness of x, this operation should 'flow' it back to positive.
-                llaccum(.add, x.limbs[k .. x.len], y.limbs[0 .. y.len]);
+                llaccum(.add, x.limbs[k..x.len], y.limbs[0..y.len]);
                 q.limbs[k] -= 1;
             }
 
@@ -1384,8 +1478,9 @@ pub const Mutable = struct {
 
         q.normalize(q.len);
 
-        // De-normalize r.
+        // De-normalize r and y.
         r.shiftRight(x.toConst(), norm_shift);
+        y.shiftRight(y.toConst(), norm_shift);
     }
 
     /// Truncate an integer to a number of bits, following 2s-complement semantics.