Commit dc1f698545

Robin Voetter <robin@voetter.nl>
2021-09-29 00:11:42
big ints: unify add/sub with their wrapping variants
1 parent a36ef84
Changed files (1)
lib
std
math
lib/std/math/big/int.zig
@@ -300,49 +300,55 @@ pub const Mutable = struct {
         return add(r, a, operand);
     }
 
-    /// r = a + b
-    ///
+    /// Base implementation for addition. Adds `max(a.limbs.len, b.limbs.len)` elements from a and b,
+    /// and returns whether any overflow occured.
     /// r, a and b may be aliases.
     ///
-    /// Asserts the result fits in `r`. An upper bound on the number of limbs needed by
-    /// r is `math.max(a.limbs.len, b.limbs.len) + 1`.
-    pub fn add(r: *Mutable, a: Const, b: Const) void {
+    /// Asserts r has enough elements to hold the result. The upper bound is `max(a.limbs.len, b.limbs.len)`.
+    fn addCarry(r: *Mutable, a: Const, b: Const) bool {
         if (a.eqZero()) {
             r.copy(b);
-            return;
+            return false;
         } else if (b.eqZero()) {
             r.copy(a);
-            return;
-        }
-
-        if (a.limbs.len == 1 and b.limbs.len == 1 and a.positive == b.positive) {
-            var o: Limb = undefined;
-            if (!@addWithOverflow(Limb, a.limbs[0], b.limbs[0], &o)) {
-                r.limbs[0] = o;
-                r.len = 1;
-                r.positive = a.positive;
-                return;
-            }
-        }
-
-        if (a.positive != b.positive) {
+            return false;
+        } else if (a.positive != b.positive) {
             if (a.positive) {
                 // (a) + (-b) => a - b
-                r.sub(a, b.abs());
+                return r.subCarry(a, b.abs());
             } else {
                 // (-a) + (b) => b - a
-                r.sub(b, a.abs());
+                return r.subCarry(b, a.abs());
             }
         } else {
+            r.positive = a.positive;
             if (a.limbs.len >= b.limbs.len) {
-                lladd(r.limbs[0..], a.limbs, b.limbs);
-                r.normalize(a.limbs.len + 1);
+                const c = lladdcarry(r.limbs, a.limbs, b.limbs);
+                r.normalize(a.limbs.len);
+                return c != 0;
             } else {
-                lladd(r.limbs[0..], b.limbs, a.limbs);
-                r.normalize(b.limbs.len + 1);
+                const c = lladdcarry(r.limbs, b.limbs, a.limbs);
+                r.normalize(b.limbs.len);
+                return c != 0;
             }
+        }
+    }
 
-            r.positive = a.positive;
+    /// r = a + b
+    ///
+    /// r, a and b may be aliases.
+    ///
+    /// Asserts the result fits in `r`. An upper bound on the number of limbs needed by
+    /// r is `math.max(a.limbs.len, b.limbs.len) + 1`.
+    pub fn add(r: *Mutable, a: Const, b: Const) void {
+        if (r.addCarry(a, b)) {
+            // Fix up the result. Note that addCarry normalizes by a.limbs.len or b.limbs.len,
+            // so we need to set the length here.
+            const msl = math.max(a.limbs.len, b.limbs.len);
+            // `[add|sub]Carry` normalizes by `msl`, so we need to fix up the result manually here.
+            // Note, the fact that it normalized means that the intermediary limbs are zero here.
+            r.len = msl + 1;
+            r.limbs[msl] = 1; // If this panics, there wasn't enough space in `r`.
         }
     }
 
@@ -354,37 +360,82 @@ pub const Mutable = struct {
     pub fn addWrap(r: *Mutable, a: Const, b: Const, signedness: std.builtin.Signedness, bit_count: usize) void {
         const req_limbs = calcTwosCompLimbCount(bit_count);
 
-        // We can ignore the upper bits here, those results will be discarded anyway.
-        const a_limbs = a.limbs[0..math.min(req_limbs, a.limbs.len)];
-        const b_limbs = b.limbs[0..math.min(req_limbs, b.limbs.len)];
+        // Slice of the upper bits if they exist, these will be ignored and allows us to use addCarry to determine
+        // if an overflow occured.
+        const x = Const{
+            .positive = a.positive,
+            .limbs = a.limbs[0..math.min(req_limbs, a.limbs.len)],
+        };
+
+        const y = Const{
+            .positive = b.positive,
+            .limbs = b.limbs[0..math.min(req_limbs, b.limbs.len)],
+        };
 
+        if (r.addCarry(x, y)) {
+            // There are two possibilities here:
+            // - We overflowed req_limbs. In this case, the carry is ignored.
+            // - a and b had less elements than req_limbs, and those were overflowed. This case needs to be handled.
+            const msl = math.max(a.limbs.len, b.limbs.len);
+            if (msl < req_limbs) {
+                r.limbs[msl] = 1;
+                r.len = req_limbs;
+            }
+        }
+
+        r.truncate(r.toConst(), signedness, bit_count);
+    }
+
+    /// Base implementation for subtraction. Subtracts `max(a.limbs.len, b.limbs.len)` elements from a and b,
+    /// and returns whether any overflow occured.
+    /// r, a and b may be aliases.
+    ///
+    /// Asserts r has enough elements to hold the result. The upper bound is `max(a.limbs.len, b.limbs.len)`.
+    fn subCarry(r: *Mutable, a: Const, b: Const) bool {
         if (a.eqZero()) {
             r.copy(b);
+            r.positive = !b.positive;
+            return false;
         } else if (b.eqZero()) {
             r.copy(a);
-        } else if (a.positive != b.positive) {
+            return false;
+        } if (a.positive != b.positive) {
             if (a.positive) {
-                // (a) + (-b) => a - b
-                r.subWrap(a, b.abs(), signedness, bit_count);
+                // (a) - (-b) => a + b
+                return r.addCarry(a, b.abs());
             } else {
-                // (-a) + (b) => b - a
-                r.subWrap(b, a.abs(), signedness, bit_count);
+                // (-a) - (b) => -a + -b
+                return r.addCarry(a, b.negate());
+            }
+        } else if (a.positive) {
+            if (a.order(b) != .lt) {
+                // (a) - (b) => a - b
+                const c = llsubcarry(r.limbs, a.limbs, b.limbs);
+                r.normalize(a.limbs.len);
+                r.positive = true;
+                return c != 0;
+            } else {
+                // (a) - (b) => -b + a => -(b - a)
+                const c = llsubcarry(r.limbs, b.limbs, a.limbs);
+                r.normalize(b.limbs.len);
+                r.positive = false;
+                return c != 0;
             }
-            // Don't need to truncate, subWrap does that for us.
-            return;
         } else {
-            if (a_limbs.len >= b_limbs.len) {
-                _ = lladdcarry(r.limbs, a_limbs, b_limbs);
-                r.normalize(a_limbs.len);
+            if (a.order(b) == .lt) {
+                // (-a) - (-b) => -(a - b)
+                const c = llsubcarry(r.limbs, a.limbs, b.limbs);
+                r.normalize(a.limbs.len);
+                r.positive = false;
+                return c != 0;
             } else {
-                _ = lladdcarry(r.limbs, b_limbs, b_limbs);
-                r.normalize(b_limbs.len);
+                // (-a) - (-b) => --b + -a => b - a
+                const c = llsubcarry(r.limbs, b.limbs, a.limbs);
+                r.normalize(b.limbs.len);
+                r.positive = true;
+                return c != 0;
             }
-
-            r.positive = a.positive;
         }
-
-        r.truncate(r.toConst(), signedness, bit_count);
     }
 
     /// r = a - b
@@ -394,39 +445,14 @@ pub const Mutable = struct {
     /// Asserts the result fits in `r`. An upper bound on the number of limbs needed by
     /// r is `math.max(a.limbs.len, b.limbs.len) + 1`. The +1 is not needed if both operands are positive.
     pub fn sub(r: *Mutable, a: Const, b: Const) void {
-        if (a.positive != b.positive) {
-            if (a.positive) {
-                // (a) - (-b) => a + b
-                r.add(a, b.abs());
-            } else {
-                // (-a) - (b) => -(a + b)
-                r.add(a.abs(), b);
-                r.positive = false;
-            }
-        } else {
-            if (a.positive) {
-                // (a) - (b) => a - b
-                if (a.order(b) != .lt) {
-                    llsub(r.limbs[0..], a.limbs[0..a.limbs.len], b.limbs[0..b.limbs.len]);
-                    r.normalize(a.limbs.len);
-                    r.positive = true;
-                } else {
-                    llsub(r.limbs[0..], b.limbs[0..b.limbs.len], a.limbs[0..a.limbs.len]);
-                    r.normalize(b.limbs.len);
-                    r.positive = false;
-                }
-            } else {
-                // (-a) - (-b) => -(a - b)
-                if (a.order(b) == .lt) {
-                    llsub(r.limbs[0..], a.limbs[0..a.limbs.len], b.limbs[0..b.limbs.len]);
-                    r.normalize(a.limbs.len);
-                    r.positive = false;
-                } else {
-                    llsub(r.limbs[0..], b.limbs[0..b.limbs.len], a.limbs[0..a.limbs.len]);
-                    r.normalize(b.limbs.len);
-                    r.positive = true;
-                }
-            }
+        if (r.subCarry(a, b)) {
+            // Fix up the result. Note that addCarry normalizes by a.limbs.len or b.limbs.len,
+            // so we need to set the length here.
+            const msl = math.max(a.limbs.len, b.limbs.len);
+            // `addCarry` normalizes by `msl`, so we need to fix up the result manually here.
+            // Note, the fact that it normalized means that the intermediary limbs are zero here.
+            r.len = msl + 1;
+            r.limbs[msl] = 1; // If this panics, there wasn't enough space in `r`.
         }
     }
 
@@ -438,45 +464,26 @@ pub const Mutable = struct {
     pub fn subWrap(r: *Mutable, a: Const, b: Const, signedness: std.builtin.Signedness, bit_count: usize) void {
         const req_limbs = calcTwosCompLimbCount(bit_count);
 
-        // We can ignore the upper bits here, those results will be discarded anyway.
-        // We also don't need to mind order here. Again, overflow is ignored here.
-        const a_limbs = a.limbs[0..math.min(req_limbs, a.limbs.len)];
-        const b_limbs = b.limbs[0..math.min(req_limbs, b.limbs.len)];
+        // Slice of the upper bits if they exist, these will be ignored and allows us to use addCarry to determine
+        // if an overflow occured.
+        const x = Const{
+            .positive = a.positive,
+            .limbs = a.limbs[0..math.min(req_limbs, a.limbs.len)],
+        };
 
-        if (a.positive != b.positive) {
-            if (a.positive) {
-                // (a) - (-b) => a + b
-                r.addWrap(a, b.abs(), signedness, bit_count);
-            } else {
-                // (-a) - (b) => -a + -b
-                // Note, we don't do -(a + b) here to avoid a second truncate.
-                r.addWrap(a, b.negate(), signedness, bit_count);
-            }
-            // Don't need to truncate, addWrap does that for us.
-            return;
-        } else if (a.positive) {
-            if (a_limbs.len >= b_limbs.len) {
-                // (a) - (b) => a - b
-                _ = llsubcarry(r.limbs, a_limbs, b_limbs);
-                r.normalize(a_limbs.len);
-                r.positive = true;
-            } else {
-                // (a) - (b) => -b + a => -(b - a)
-                _ = llsubcarry(r.limbs, b_limbs, a_limbs);
-                r.normalize(b_limbs.len);
-                r.positive = false;
-            }
-        } else {
-            if (a_limbs.len >= b_limbs.len) {
-                // (-a) - (-b) => -(a - b)
-                _ = llsubcarry(r.limbs, a_limbs, b_limbs);
-                r.normalize(a_limbs.len);
-                r.positive = false;
-            } else {
-                // (-a) - (-b) => --b + -a => b - a
-                _ = llsubcarry(r.limbs, b_limbs, a_limbs);
-                r.normalize(b_limbs.len);
-                r.positive = true;
+        const y = Const{
+            .positive = b.positive,
+            .limbs = b.limbs[0..math.min(req_limbs, b.limbs.len)],
+        };
+
+        if (r.subCarry(x, y)) {
+            // There are two possibilities here:
+            // - We overflowed req_limbs. In this case, the carry is ignored.
+            // - a and b had less elements than req_limbs, and those were overflowed. This case needs to be handled.
+            const msl = math.max(a.limbs.len, b.limbs.len);
+            if (msl < req_limbs) {
+                r.limbs[msl] = 1;
+                r.len = req_limbs;
             }
         }