Commit e7d9d00ac8

Josh Wolfe <thejoshwolfe@gmail.com>
2018-09-27 06:35:38
overhaul api for getting random integers (#1578)
* rand api overhaul * no retry limits. instead documented a recommendation to call int(T) % len directly.
1 parent 1c26c2f
Changed files (1)
std
std/rand/index.zig
@@ -5,11 +5,11 @@
 // ```
 // var buf: [8]u8 = undefined;
 // try std.os.getRandomBytes(buf[0..]);
-// const seed = mem.readInt(buf[0..8], u64, builtin.Endian.Little);
+// const seed = mem.readIntLE(u64, buf[0..8]);
 //
 // var r = DefaultPrng.init(seed);
 //
-// const s = r.random.scalar(u64);
+// const s = r.random.int(u64);
 // ```
 //
 // TODO(tiehuis): Benchmark these against other reference implementations.
@@ -35,60 +35,117 @@ pub const Random = struct {
         r.fillFn(r, buf);
     }
 
-    /// Return a random integer/boolean type.
-    pub fn scalar(r: *Random, comptime T: type) T {
-        var rand_bytes: [@sizeOf(T)]u8 = undefined;
+    pub fn boolean(r: *Random) bool {
+        return r.int(u1) != 0;
+    }
+
+    /// Returns a random int `i` such that `0 <= i <= @maxValue(T)`.
+    /// `i` is evenly distributed.
+    pub fn int(r: *Random, comptime T: type) T {
+        const UnsignedT = @IntType(false, T.bit_count);
+        const ByteAlignedT = @IntType(false, @divTrunc(T.bit_count + 7, 8) * 8);
+
+        var rand_bytes: [@sizeOf(ByteAlignedT)]u8 = undefined;
         r.bytes(rand_bytes[0..]);
 
-        if (T == bool) {
-            return rand_bytes[0] & 0b1 == 0;
+        // use LE instead of native endian for better portability maybe?
+        // TODO: endian portability is pointless if the underlying prng isn't endian portable.
+        // TODO: document the endian portability of this library.
+        const byte_aligned_result = mem.readIntLE(ByteAlignedT, rand_bytes);
+        const unsigned_result = @truncate(UnsignedT, byte_aligned_result);
+        return @bitCast(T, unsigned_result);
+    }
+
+    /// Returns an evenly distributed random unsigned integer `0 <= i < less_than`.
+    /// This function assumes that the underlying ::fillFn produces evenly distributed values.
+    /// Within this assumption, the runtime of this function is exponentially distributed.
+    /// If ::fillFn were backed by a true random generator,
+    /// the runtime of this function would technically be unbounded.
+    /// However, if ::fillFn is backed by any evenly distributed pseudo random number generator,
+    /// this function is guaranteed to return.
+    /// If you need deterministic runtime bounds, consider instead using `r.int(T) % less_than`,
+    /// which will usually be biased toward smaller values.
+    pub fn uintLessThan(r: *Random, comptime T: type, less_than: T) T {
+        assert(T.is_signed == false);
+        assert(0 < less_than);
+
+        const last_group_size_minus_one: T = @maxValue(T) % less_than;
+        if (last_group_size_minus_one == less_than - 1) {
+            // less_than is a power of two.
+            assert(math.floorPowerOfTwo(T, less_than) == less_than);
+            // There is no retry zone. The optimal retry_zone_start would be @maxValue(T) + 1.
+            return r.int(T) % less_than;
+        }
+        const retry_zone_start = @maxValue(T) - last_group_size_minus_one;
+
+        while (true) {
+            const rand_val = r.int(T);
+            if (rand_val < retry_zone_start) {
+                return rand_val % less_than;
+            }
+        }
+    }
+
+    /// Returns an evenly distributed random unsigned integer `0 <= i <= at_most`.
+    /// See ::uintLessThan, which this function uses in most cases,
+    /// for commentary on the runtime of this function.
+    pub fn uintAtMost(r: *Random, comptime T: type, at_most: T) T {
+        assert(T.is_signed == false);
+        if (at_most == @maxValue(T)) {
+            // have the full range
+            return r.int(T);
+        }
+        return r.uintLessThan(T, at_most + 1);
+    }
+
+    /// Returns an evenly distributed random integer `at_least <= i < less_than`.
+    /// See ::uintLessThan, which this function uses in most cases,
+    /// for commentary on the runtime of this function.
+    pub fn intRangeLessThan(r: *Random, comptime T: type, at_least: T, less_than: T) T {
+        assert(at_least < less_than);
+        if (T.is_signed) {
+            // Two's complement makes this math pretty easy.
+            const UnsignedT = @IntType(false, T.bit_count);
+            const lo = @bitCast(UnsignedT, at_least);
+            const hi = @bitCast(UnsignedT, less_than);
+            const result = lo +% r.uintLessThan(UnsignedT, hi -% lo);
+            return @bitCast(T, result);
+        } else {
+            // The signed implementation would work fine, but we can use stricter arithmetic operators here.
+            return at_least + r.uintLessThan(T, less_than - at_least);
+        }
+    }
+
+    /// Returns an evenly distributed random integer `at_least <= i <= at_most`.
+    /// See ::uintLessThan, which this function uses in most cases,
+    /// for commentary on the runtime of this function.
+    pub fn intRangeAtMost(r: *Random, comptime T: type, at_least: T, at_most: T) T {
+        assert(at_least <= at_most);
+        if (T.is_signed) {
+            // Two's complement makes this math pretty easy.
+            const UnsignedT = @IntType(false, T.bit_count);
+            const lo = @bitCast(UnsignedT, at_least);
+            const hi = @bitCast(UnsignedT, at_most);
+            const result = lo +% r.uintAtMost(UnsignedT, hi -% lo);
+            return @bitCast(T, result);
         } else {
-            // NOTE: Cannot @bitCast array to integer type.
-            return mem.readInt(rand_bytes, T, builtin.Endian.Little);
+            // The signed implementation would work fine, but we can use stricter arithmetic operators here.
+            return at_least + r.uintAtMost(T, at_most - at_least);
         }
     }
 
+    /// Return a random integer/boolean type.
+    /// TODO: deprecated. use ::boolean or ::int instead.
+    pub fn scalar(r: *Random, comptime T: type) T {
+        if (T == bool) return r.boolean();
+        return r.int(T);
+    }
+
     /// Return a random integer with even distribution between `start`
     /// inclusive and `end` exclusive.  `start` must be less than `end`.
+    /// TODO: deprecated. renamed to ::intRangeLessThan
     pub fn range(r: *Random, comptime T: type, start: T, end: T) T {
-        assert(start < end);
-        if (T.is_signed) {
-            const uint = @IntType(false, T.bit_count);
-            if (start >= 0 and end >= 0) {
-                return @intCast(T, r.range(uint, @intCast(uint, start), @intCast(uint, end)));
-            } else if (start < 0 and end < 0) {
-                // Can't overflow because the range is over signed ints
-                return math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1) catch unreachable;
-            } else if (start < 0 and end >= 0) {
-                const end_uint = @intCast(uint, end);
-                const total_range = math.absCast(start) + end_uint;
-                const value = r.range(uint, 0, total_range);
-                const result = if (value < end_uint) x: {
-                    break :x @intCast(T, value);
-                } else if (value == end_uint) x: {
-                    break :x start;
-                } else x: {
-                    // Can't overflow because the range is over signed ints
-                    break :x math.negateCast(value - end_uint) catch unreachable;
-                };
-                return result;
-            } else {
-                unreachable;
-            }
-        } else {
-            const total_range = end - start;
-            const leftover = @maxValue(T) % total_range;
-            const upper_bound = @maxValue(T) - leftover;
-            var rand_val_array: [@sizeOf(T)]u8 = undefined;
-
-            while (true) {
-                r.bytes(rand_val_array[0..]);
-                const rand_val = mem.readInt(rand_val_array, T, builtin.Endian.Little);
-                if (rand_val < upper_bound) {
-                    return start + (rand_val % total_range);
-                }
-            }
-        }
+        return r.intRangeLessThan(T, start, end);
     }
 
     /// Return a floating point value evenly distributed in the range [0, 1).
@@ -97,12 +154,12 @@ pub const Random = struct {
         // Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
         switch (T) {
             f32 => {
-                const s = r.scalar(u32);
+                const s = r.int(u32);
                 const repr = (0x7f << 23) | (s >> 9);
                 return @bitCast(f32, repr) - 1.0;
             },
             f64 => {
-                const s = r.scalar(u64);
+                const s = r.int(u64);
                 const repr = (0x3ff << 52) | (s >> 12);
                 return @bitCast(f64, repr) - 1.0;
             },
@@ -142,12 +199,167 @@ pub const Random = struct {
 
         var i: usize = 0;
         while (i < buf.len - 1) : (i += 1) {
-            const j = r.range(usize, i, buf.len);
+            const j = r.intRangeLessThan(usize, i, buf.len);
             mem.swap(T, &buf[i], &buf[j]);
         }
     }
 };
 
+const SequentialPrng = struct {
+    const Self = @This();
+    random: Random,
+    next_value: u8,
+
+    pub fn init() Self {
+        return Self{
+            .random = Random{ .fillFn = fill },
+            .next_value = 0,
+        };
+    }
+
+    fn fill(r: *Random, buf: []u8) void {
+        const self = @fieldParentPtr(Self, "random", r);
+        for (buf) |*b| {
+            b.* = self.next_value;
+        }
+        self.next_value +%= 1;
+    }
+};
+
+test "Random int" {
+    testRandomInt();
+    comptime testRandomInt();
+}
+fn testRandomInt() void {
+    var r = SequentialPrng.init();
+
+    assert(r.random.int(u0) == 0);
+
+    r.next_value = 0;
+    assert(r.random.int(u1) == 0);
+    assert(r.random.int(u1) == 1);
+    assert(r.random.int(u2) == 2);
+    assert(r.random.int(u2) == 3);
+    assert(r.random.int(u2) == 0);
+
+    r.next_value = 0xff;
+    assert(r.random.int(u8) == 0xff);
+    r.next_value = 0x11;
+    assert(r.random.int(u8) == 0x11);
+
+    r.next_value = 0xff;
+    assert(r.random.int(u32) == 0xffffffff);
+    r.next_value = 0x11;
+    assert(r.random.int(u32) == 0x11111111);
+
+    r.next_value = 0xff;
+    assert(r.random.int(i32) == -1);
+    r.next_value = 0x11;
+    assert(r.random.int(i32) == 0x11111111);
+
+    r.next_value = 0xff;
+    assert(r.random.int(i8) == -1);
+    r.next_value = 0x11;
+    assert(r.random.int(i8) == 0x11);
+
+    r.next_value = 0xff;
+    assert(r.random.int(u33) == 0x1ffffffff);
+    r.next_value = 0xff;
+    assert(r.random.int(i1) == -1);
+    r.next_value = 0xff;
+    assert(r.random.int(i2) == -1);
+    r.next_value = 0xff;
+    assert(r.random.int(i33) == -1);
+}
+
+test "Random boolean" {
+    testRandomBoolean();
+    comptime testRandomBoolean();
+}
+fn testRandomBoolean() void {
+    var r = SequentialPrng.init();
+    assert(r.random.boolean() == false);
+    assert(r.random.boolean() == true);
+    assert(r.random.boolean() == false);
+    assert(r.random.boolean() == true);
+}
+
+test "Random intLessThan" {
+    @setEvalBranchQuota(10000);
+    testRandomIntLessThan();
+    comptime testRandomIntLessThan();
+}
+fn testRandomIntLessThan() void {
+    var r = SequentialPrng.init();
+    r.next_value = 0xff;
+    assert(r.random.uintLessThan(u8, 4) == 3);
+    r.next_value = 0xff;
+    assert(r.random.uintLessThan(u8, 3) == 0);
+    assert(r.next_value == 1);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(u8, 0, 0x80) == 0x7f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(u8, 0x7f, 0xff) == 0xfe);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i8, 0, 0x40) == 0x3f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i8, -0x40, 0x40) == 0x3f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i8, -0x80, 0) == -1);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i64, -0x8000000000000000, 0) == -1);
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i3, -4, 0) == -1);
+    r.next_value = 0xff;
+    assert(r.random.intRangeLessThan(i3, -2, 2) == 1);
+
+    // test retrying and eventually getting a good value
+    // start just out of bounds
+    r.next_value = 0x81;
+    assert(r.random.uintLessThan(u8, 0x81) == 0);
+}
+
+test "Random intAtMost" {
+    @setEvalBranchQuota(10000);
+    testRandomIntAtMost();
+    comptime testRandomIntAtMost();
+}
+fn testRandomIntAtMost() void {
+    var r = SequentialPrng.init();
+    r.next_value = 0xff;
+    assert(r.random.uintAtMost(u8, 3) == 3);
+    r.next_value = 0xff;
+    assert(r.random.uintAtMost(u8, 2) == 0);
+    assert(r.next_value == 1);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(u8, 0, 0x7f) == 0x7f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(u8, 0x7f, 0xfe) == 0xfe);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i8, 0, 0x3f) == 0x3f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i8, -0x40, 0x3f) == 0x3f);
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i8, -0x80, -1) == -1);
+
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i64, -0x8000000000000000, -1) == -1);
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i3, -4, -1) == -1);
+    r.next_value = 0xff;
+    assert(r.random.intRangeAtMost(i3, -2, 1) == 1);
+
+    // test retrying and eventually getting a good value
+    // start just out of bounds
+    r.next_value = 0x81;
+    assert(r.random.uintAtMost(u8, 0x80) == 0);
+}
+
 // Generator to extend 64-bit seed values into longer sequences.
 //
 // The number of cycles is thus limited to 64-bits regardless of the engine, but this
@@ -622,17 +834,6 @@ test "Random float" {
     }
 }
 
-test "Random scalar" {
-    var prng = DefaultPrng.init(0);
-    const s = prng.random.scalar(u64);
-}
-
-test "Random bytes" {
-    var prng = DefaultPrng.init(0);
-    var buf: [2048]u8 = undefined;
-    prng.random.bytes(buf[0..]);
-}
-
 test "Random shuffle" {
     var prng = DefaultPrng.init(0);
 
@@ -664,16 +865,16 @@ test "Random range" {
     testRange(&prng.random, -4, 3);
     testRange(&prng.random, -4, -1);
     testRange(&prng.random, 10, 14);
-    // TODO: test that prng.random.range(1, 1) causes an assertion error
+    testRange(&prng.random, -0x80, 0x7f);
 }
 
-fn testRange(r: *Random, start: i32, end: i32) void {
-    const count = @intCast(usize, end - start);
-    var values_buffer = []bool{false} ** 20;
+fn testRange(r: *Random, start: i8, end: i8) void {
+    const count = @intCast(usize, i32(end) - i32(start));
+    var values_buffer = []bool{false} ** 0x100;
     const values = values_buffer[0..count];
     var i: usize = 0;
     while (i < count) {
-        const value = r.range(i32, start, end);
+        const value: i32 = r.intRangeLessThan(i8, start, end);
         const index = @intCast(usize, value - start);
         if (!values[index]) {
             i += 1;