Commit 1e301b03a9

Andrew Kelley <superjoe30@gmail.com>
2017-06-01 00:23:56
change std.rand.Rand.rangeUnsigned to std.rand.Rand.range
and make it support signed integers
1 parent 1ae2002
Changed files (3)
example
guess_number
std
example/guess_number/main.zig
@@ -12,7 +12,7 @@ pub fn main() -> %void {
     const seed = std.mem.readInt(seed_bytes, usize, true);
     var rand = Rand.init(seed);
 
-    const answer = rand.rangeUnsigned(u8, 0, 100) + 1;
+    const answer = rand.range(u8, 0, 100) + 1;
 
     while (true) {
         %%io.stdout.printf("\nGuess a number between 1 and 100: ");
std/math.zig
@@ -41,6 +41,10 @@ pub fn sub(comptime T: type, a: T, b: T) -> %T {
     if (@subWithOverflow(T, a, b, &answer)) error.Overflow else answer
 }
 
+pub fn negate(x: var) -> %@typeOf(x) {
+    return sub(@typeOf(x), 0, x);
+}
+
 error Overflow;
 pub fn shl(comptime T: type, a: T, b: T) -> %T {
     var answer: T = undefined;
@@ -341,3 +345,51 @@ test "math.floor" {
     assert(floor(f64(999.0)) == 999.0);
     assert(floor(f64(-999.0)) == -999.0);
 }
+
+/// Returns the absolute value of the integer parameter.
+/// Result is an unsigned integer.
+pub fn absCast(x: var) -> @IntType(false, @typeOf(x).bit_count) {
+    const uint = @IntType(false, @typeOf(x).bit_count);
+    if (x >= 0)
+        return uint(x);
+
+    return uint(-(x + 1)) + 1;
+}
+
+test "math.absCast" {
+    assert(absCast(i32(-999)) == 999);
+    assert(@typeOf(absCast(i32(-999))) == u32);
+
+    assert(absCast(i32(999)) == 999);
+    assert(@typeOf(absCast(i32(999))) == u32);
+
+    assert(absCast(i32(@minValue(i32))) == -@minValue(i32));
+    assert(@typeOf(absCast(i32(@minValue(i32)))) == u32);
+}
+
+/// Returns the negation of the integer parameter.
+/// Result is a signed integer.
+error Overflow;
+pub fn negateCast(x: var) -> %@IntType(true, @typeOf(x).bit_count) {
+    if (@typeOf(x).is_signed)
+        return negate(x);
+
+    const int = @IntType(true, @typeOf(x).bit_count);
+    if (x > -@minValue(int))
+        return error.Overflow;
+
+    if (x == -@minValue(int))
+        return @minValue(int);
+
+    return -int(x);
+}
+
+test "math.negateCast" {
+    assert(%%negateCast(u32(999)) == -999);
+    assert(@typeOf(%%negateCast(u32(999))) == i32);
+
+    assert(%%negateCast(u32(-@minValue(i32))) == @minValue(i32));
+    assert(@typeOf(%%negateCast(u32(-@minValue(i32)))) == i32);
+
+    if (negateCast(u32(@maxValue(i32) + 10))) |_| unreachable else |err| assert(err == error.Overflow);
+}
std/rand.zig
@@ -1,6 +1,7 @@
 const assert = @import("debug.zig").assert;
 const rand_test = @import("rand_test.zig");
 const mem = @import("mem.zig");
+const math = @import("math.zig");
 
 pub const MT19937_32 = MersenneTwister(
     u32, 624, 397, 31,
@@ -63,18 +64,43 @@ pub const Rand = struct {
 
     /// Get a random unsigned integer with even distribution between `start`
     /// inclusive and `end` exclusive.
-    // TODO support signed integers and then rename to "range"
-    pub fn rangeUnsigned(r: &Rand, comptime T: type, start: T, end: T) -> T {
-        const range = end - start;
-        const leftover = @maxValue(T) % range;
-        const upper_bound = @maxValue(T) - leftover;
-        var rand_val_array: [@sizeOf(T)]u8 = undefined;
-
-        while (true) {
-            r.fillBytes(rand_val_array[0..]);
-            const rand_val = mem.readInt(rand_val_array, T, false);
-            if (rand_val < upper_bound) {
-                return start + (rand_val % range);
+    pub fn range(r: &Rand, comptime T: type, start: T, end: T) -> T {
+        assert(start <= end);
+        if (T.is_signed) {
+            const uint = @IntType(false, T.bit_count);
+            if (start >= 0 and end >= 0) {
+                return T(r.range(uint, uint(start), uint(end)));
+            } else if (start < 0 and end < 0) {
+                // Can't overflow because the range is over signed ints
+                return %%math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1);
+            } else if (start < 0 and end >= 0) {
+                const end_uint = uint(end);
+                const total_range = math.absCast(start) + end_uint;
+                const value = r.range(uint, 0, total_range);
+                const result = if (value < end_uint) {
+                    T(value)
+                } else if (value == end_uint) {
+                    start
+                } else {
+                    // Can't overflow because the range is over signed ints
+                    %%math.negateCast(value - end_uint)
+                };
+                return result;
+            } else {
+                unreachable;
+            }
+        } else {
+            const total_range = end - start;
+            const leftover = @maxValue(T) % total_range;
+            const upper_bound = @maxValue(T) - leftover;
+            var rand_val_array: [@sizeOf(T)]u8 = undefined;
+
+            while (true) {
+                r.fillBytes(rand_val_array[0..]);
+                const rand_val = mem.readInt(rand_val_array, T, false);
+                if (rand_val < upper_bound) {
+                    return start + (rand_val % total_range);
+                }
             }
         }
     }
@@ -94,7 +120,7 @@ pub const Rand = struct {
         } else {
             @compileError("unknown floating point type")
         };
-        return T(r.rangeUnsigned(int_type, 0, precision)) / T(precision);
+        return T(r.range(int_type, 0, precision)) / T(precision);
     }
 };
 
@@ -175,16 +201,38 @@ test "rand float 32" {
     }
 }
 
-test "testMT19937_64" {
+test "rand.MT19937_64" {
     var rng = MT19937_64.init(rand_test.mt64_seed);
     for (rand_test.mt64_data) |value| {
         assert(value == rng.get());
     }
 }
 
-test "testMT19937_32" {
+test "rand.MT19937_32" {
     var rng = MT19937_32.init(rand_test.mt32_seed);
     for (rand_test.mt32_data) |value| {
         assert(value == rng.get());
     }
 }
+
+test "rand.Rand.range" {
+    var r = Rand.init(42);
+    testRange(&r, -4, 3);
+    testRange(&r, -4, -1);
+    testRange(&r, 10, 14);
+}
+
+fn testRange(r: &Rand, start: i32, end: i32) {
+    const count = usize(end - start);
+    var values_buffer = []bool{false} ** 20;
+    const values = values_buffer[0..count];
+    var i: usize = 0;
+    while (i < count) {
+        const value = r.range(i32, start, end);
+        const index = usize(value - start);
+        if (!values[index]) {
+            i += 1;
+            values[index] = true;
+        }
+    }
+}