Commit dbc11be038

LemonBoy <thatlemon@gmail.com>
2020-10-05 20:36:09
std: Fix two bugs in bigint pow
* Correctly scan all the exponent bits, this caused the incorrect result to be computed for exponents being powers of two. * Allocate enough limbs to make llmulacc stop whining.
1 parent 7f7e2d6
Changed files (2)
lib
std
lib/std/math/big/int.zig
@@ -59,8 +59,8 @@ pub fn calcSetStringLimbCount(base: u8, string_len: usize) usize {
 }
 
 pub fn calcPowLimbsBufferLen(a_bit_count: usize, y: usize) usize {
-    // The 1 accounts for the multiplication carry
-    return 1 + (a_bit_count * y + (limb_bits - 1)) / limb_bits;
+    // The 2 accounts for the minimum space requirement for llmulacc
+    return 2 + (a_bit_count * y + (limb_bits - 1)) / limb_bits;
 }
 
 /// a + b * c + *carry, sets carry to the overflow bits
@@ -2205,47 +2205,51 @@ fn llxor(r: []Limb, a: []const Limb, b: []const Limb) void {
 
 /// Knuth 4.6.3
 fn llpow(r: []Limb, a: []const Limb, b: u32, tmp_limbs: []Limb) void {
-    mem.copy(Limb, r, a);
-    mem.set(Limb, r[a.len..], 0);
+    var tmp1: []Limb = undefined;
+    var tmp2: []Limb = undefined;
 
     // Multiplication requires no aliasing between the operand and the result
     // variable, use the output limbs and another temporary set to overcome this
-    // limit.
-    // Note that the order is important in the code below.
-    var list = [_][]Limb{ r, tmp_limbs };
-    var index: usize = 0;
+    // limitation.
+    // The initial assignment makes the result end in `r` so an extra memory
+    // copy is saved, each 1 flips the index twice so it's a no-op so count the
+    // 0.
+    const b_leading_zeros = @intCast(u5, @clz(u32, b));
+    const exp_zeros = @popCount(u32, ~b) - b_leading_zeros;
+    if (exp_zeros & 1 != 0) {
+        tmp1 = tmp_limbs;
+        tmp2 = r;
+    } else {
+        tmp1 = r;
+        tmp2 = tmp_limbs;
+    }
+
+    const a_norm = a[0..llnormalize(a)];
+
+    mem.copy(Limb, tmp1, a_norm);
+    mem.set(Limb, tmp1[a_norm.len..], 0);
 
     // Scan the exponent as a binary number, from left to right, dropping the
-    // most significant bit set
-    var exp = @bitReverse(u32, b) >> (1 + @intCast(u5, @clz(u32, b)));
-    while (exp != 0) : (exp >>= 1) {
+    // most significant bit set.
+    const exp_bits = @intCast(u5, 31 - b_leading_zeros);
+    var exp = @bitReverse(u32, b) >> 1 + b_leading_zeros;
+
+    var i: u5 = 0;
+    while (i < exp_bits) : (i += 1) {
         // Square
         {
-            const cur_buf = list[index];
-            const cur_buf_len = llnormalize(cur_buf);
-            const cur_buf_out = list[index ^ 1];
-
-            mem.set(Limb, cur_buf_out, 0);
-            llmulacc(null, cur_buf_out, cur_buf[0..cur_buf_len], cur_buf[0..cur_buf_len]);
-
-            index ^= 1;
+            mem.set(Limb, tmp2, 0);
+            const op = tmp1[0..llnormalize(tmp1)];
+            llmulacc(null, tmp2, op, op);
+            mem.swap([]Limb, &tmp1, &tmp2);
         }
-
-        if ((exp & 1) != 0) {
-            // Multiply
-            const cur_buf = list[index];
-            const cur_buf_len = llnormalize(cur_buf);
-            const cur_buf_out = list[index ^ 1];
-
-            mem.set(Limb, cur_buf_out, 0);
-            llmulacc(null, cur_buf_out, cur_buf, a);
-
-            index ^= 1;
+        // Multiply by a
+        if (exp & 1 != 0) {
+            mem.set(Limb, tmp2, 0);
+            llmulacc(null, tmp2, tmp1[0..llnormalize(tmp1)], a_norm);
+            mem.swap([]Limb, &tmp1, &tmp2);
         }
-    }
-
-    if (index != 0) {
-        mem.copy(Limb, r, tmp_limbs);
+        exp >>= 1;
     }
 }
 
lib/std/math/big/int_test.zig
@@ -1482,6 +1482,13 @@ test "big.int const to managed" {
 }
 
 test "big.int pow" {
+    {
+        var a = try Managed.initSet(testing.allocator, 10);
+        defer a.deinit();
+
+        try a.pow(a, 8);
+        testing.expectEqual(@as(u32, 100000000), try a.to(u32));
+    }
     {
         var a = try Managed.initSet(testing.allocator, 10);
         defer a.deinit();