Commit 1924ffa67d

Josh Wolfe <thejoshwolfe@gmail.com>
2018-11-21 23:33:37
better debiased random range implementation
1 parent 4d747d4
Changed files (1)
std
std/rand/index.zig
@@ -69,22 +69,36 @@ pub const Random = struct {
     pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T {
         assert(T.is_signed == false);
         assert(0 < less_than);
-
-        const last_group_size_minus_one: T = maxInt(T) % less_than;
-        if (last_group_size_minus_one == less_than - 1) {
-            // less_than is a power of two.
-            assert(math.floorPowerOfTwo(T, less_than) == less_than);
-            // There is no retry zone. The optimal retry_zone_start would be maxInt(T) + 1.
-            return r.int(T) % less_than;
-        }
-        const retry_zone_start = maxInt(T) - last_group_size_minus_one;
-
-        while (true) {
-            const rand_val = r.int(T);
-            if (rand_val < retry_zone_start) {
-                return rand_val % less_than;
+        // Small is typically u32
+        const Small = @IntType(false, @divTrunc(T.bit_count + 31, 32) * 32);
+        // Large is typically u64
+        const Large = @IntType(false, Small.bit_count * 2);
+
+        // adapted from:
+        //   http://www.pcg-random.org/posts/bounded-rands.html
+        //   "Lemire's (with an extra tweak from me)"
+        var x: Small = r.int(Small);
+        var m: Large = Large(x) * Large(less_than);
+        var l: Small = @truncate(Small, m);
+        if (l < less_than) {
+            // TODO: workaround for https://github.com/ziglang/zig/issues/1770
+            // should be:
+            //   var t: Small = -%less_than;
+            var t: Small = @bitCast(Small, -%@bitCast(@IntType(true, Small.bit_count), Small(less_than)));
+
+            if (t >= less_than) {
+                t -= less_than;
+                if (t >= less_than) {
+                    t %= less_than;
+                }
+            }
+            while (l < t) {
+                x = r.int(Small);
+                m = Large(x) * Large(less_than);
+                l = @truncate(Small, m);
             }
         }
+        return @intCast(T, m >> Small.bit_count);
     }
 
     /// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`.
@@ -294,10 +308,19 @@ fn testRandomIntLessThan() void {
     var r = SequentialPrng.init();
     r.next_value = 0xff;
     assert(r.random.uintLessThan(u8, 4) == 3);
-    r.next_value = 0xff;
-    assert(r.random.uintLessThan(u8, 3) == 0);
+    assert(r.next_value == 0);
+    assert(r.random.uintLessThan(u8, 4) == 0);
     assert(r.next_value == 1);
 
+    r.next_value = 0;
+    assert(r.random.uintLessThan(u64, 32) == 0);
+
+    // trigger the bias rejection code path
+    r.next_value = 0;
+    assert(r.random.uintLessThan(u8, 3) == 0);
+    // verify we incremented twice
+    assert(r.next_value == 2);
+
     r.next_value = 0xff;
     assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
     r.next_value = 0xff;
@@ -310,17 +333,10 @@ fn testRandomIntLessThan() void {
     r.next_value = 0xff;
     assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1);
 
-    r.next_value = 0xff;
-    assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1);
     r.next_value = 0xff;
     assert(r.random.intRangeLessThan(i3, -4, 0) == -1);
     r.next_value = 0xff;
     assert(r.random.intRangeLessThan(i3, -2, 2) == 1);
-
-    // test retrying and eventually getting a good value
-    // start just out of bounds
-    r.next_value = 0x81;
-    assert(r.random.uintLessThan(u8, 0x81) == 0);
 }
 
 test "Random intAtMost" {
@@ -332,9 +348,14 @@ fn testRandomIntAtMost() void {
     var r = SequentialPrng.init();
     r.next_value = 0xff;
     assert(r.random.uintAtMost(u8, 3) == 3);
-    r.next_value = 0xff;
+    assert(r.next_value == 0);
+    assert(r.random.uintAtMost(u8, 3) == 0);
+
+    // trigger the bias rejection code path
+    r.next_value = 0;
     assert(r.random.uintAtMost(u8, 2) == 0);
-    assert(r.next_value == 1);
+    // verify we incremented twice
+    assert(r.next_value == 2);
 
     r.next_value = 0xff;
     assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
@@ -348,17 +369,10 @@ fn testRandomIntAtMost() void {
     r.next_value = 0xff;
     assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1);
 
-    r.next_value = 0xff;
-    assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1);
     r.next_value = 0xff;
     assert(r.random.intRangeAtMost(i3, -4, -1) == -1);
     r.next_value = 0xff;
     assert(r.random.intRangeAtMost(i3, -2, 1) == 1);
-
-    // test retrying and eventually getting a good value
-    // start just out of bounds
-    r.next_value = 0x81;
-    assert(r.random.uintAtMost(u8, 0x80) == 0);
 }
 
 // Generator to extend 64-bit seed values into longer sequences.