Commit 1e301b03a9
example/guess_number/main.zig
@@ -12,7 +12,7 @@ pub fn main() -> %void {
const seed = std.mem.readInt(seed_bytes, usize, true);
var rand = Rand.init(seed);
- const answer = rand.rangeUnsigned(u8, 0, 100) + 1;
+ const answer = rand.range(u8, 0, 100) + 1;
while (true) {
%%io.stdout.printf("\nGuess a number between 1 and 100: ");
std/math.zig
@@ -41,6 +41,10 @@ pub fn sub(comptime T: type, a: T, b: T) -> %T {
if (@subWithOverflow(T, a, b, &answer)) error.Overflow else answer
}
+pub fn negate(x: var) -> %@typeOf(x) {
+ return sub(@typeOf(x), 0, x);
+}
+
error Overflow;
pub fn shl(comptime T: type, a: T, b: T) -> %T {
var answer: T = undefined;
@@ -341,3 +345,51 @@ test "math.floor" {
assert(floor(f64(999.0)) == 999.0);
assert(floor(f64(-999.0)) == -999.0);
}
+
+/// Returns the absolute value of the integer parameter.
+/// Result is an unsigned integer.
+pub fn absCast(x: var) -> @IntType(false, @typeOf(x).bit_count) {
+ const uint = @IntType(false, @typeOf(x).bit_count);
+ if (x >= 0)
+ return uint(x);
+
+ return uint(-(x + 1)) + 1;
+}
+
+test "math.absCast" {
+ assert(absCast(i32(-999)) == 999);
+ assert(@typeOf(absCast(i32(-999))) == u32);
+
+ assert(absCast(i32(999)) == 999);
+ assert(@typeOf(absCast(i32(999))) == u32);
+
+ assert(absCast(i32(@minValue(i32))) == -@minValue(i32));
+ assert(@typeOf(absCast(i32(@minValue(i32)))) == u32);
+}
+
+/// Returns the negation of the integer parameter.
+/// Result is a signed integer.
+error Overflow;
+pub fn negateCast(x: var) -> %@IntType(true, @typeOf(x).bit_count) {
+ if (@typeOf(x).is_signed)
+ return negate(x);
+
+ const int = @IntType(true, @typeOf(x).bit_count);
+ if (x > -@minValue(int))
+ return error.Overflow;
+
+ if (x == -@minValue(int))
+ return @minValue(int);
+
+ return -int(x);
+}
+
+test "math.negateCast" {
+ assert(%%negateCast(u32(999)) == -999);
+ assert(@typeOf(%%negateCast(u32(999))) == i32);
+
+ assert(%%negateCast(u32(-@minValue(i32))) == @minValue(i32));
+ assert(@typeOf(%%negateCast(u32(-@minValue(i32)))) == i32);
+
+ if (negateCast(u32(@maxValue(i32) + 10))) |_| unreachable else |err| assert(err == error.Overflow);
+}
std/rand.zig
@@ -1,6 +1,7 @@
const assert = @import("debug.zig").assert;
const rand_test = @import("rand_test.zig");
const mem = @import("mem.zig");
+const math = @import("math.zig");
pub const MT19937_32 = MersenneTwister(
u32, 624, 397, 31,
@@ -63,18 +64,43 @@ pub const Rand = struct {
/// Get a random unsigned integer with even distribution between `start`
/// inclusive and `end` exclusive.
- // TODO support signed integers and then rename to "range"
- pub fn rangeUnsigned(r: &Rand, comptime T: type, start: T, end: T) -> T {
- const range = end - start;
- const leftover = @maxValue(T) % range;
- const upper_bound = @maxValue(T) - leftover;
- var rand_val_array: [@sizeOf(T)]u8 = undefined;
-
- while (true) {
- r.fillBytes(rand_val_array[0..]);
- const rand_val = mem.readInt(rand_val_array, T, false);
- if (rand_val < upper_bound) {
- return start + (rand_val % range);
+ pub fn range(r: &Rand, comptime T: type, start: T, end: T) -> T {
+ assert(start <= end);
+ if (T.is_signed) {
+ const uint = @IntType(false, T.bit_count);
+ if (start >= 0 and end >= 0) {
+ return T(r.range(uint, uint(start), uint(end)));
+ } else if (start < 0 and end < 0) {
+ // Can't overflow because the range is over signed ints
+ return %%math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1);
+ } else if (start < 0 and end >= 0) {
+ const end_uint = uint(end);
+ const total_range = math.absCast(start) + end_uint;
+ const value = r.range(uint, 0, total_range);
+ const result = if (value < end_uint) {
+ T(value)
+ } else if (value == end_uint) {
+ start
+ } else {
+ // Can't overflow because the range is over signed ints
+ %%math.negateCast(value - end_uint)
+ };
+ return result;
+ } else {
+ unreachable;
+ }
+ } else {
+ const total_range = end - start;
+ const leftover = @maxValue(T) % total_range;
+ const upper_bound = @maxValue(T) - leftover;
+ var rand_val_array: [@sizeOf(T)]u8 = undefined;
+
+ while (true) {
+ r.fillBytes(rand_val_array[0..]);
+ const rand_val = mem.readInt(rand_val_array, T, false);
+ if (rand_val < upper_bound) {
+ return start + (rand_val % total_range);
+ }
}
}
}
@@ -94,7 +120,7 @@ pub const Rand = struct {
} else {
@compileError("unknown floating point type")
};
- return T(r.rangeUnsigned(int_type, 0, precision)) / T(precision);
+ return T(r.range(int_type, 0, precision)) / T(precision);
}
};
@@ -175,16 +201,38 @@ test "rand float 32" {
}
}
-test "testMT19937_64" {
+test "rand.MT19937_64" {
var rng = MT19937_64.init(rand_test.mt64_seed);
for (rand_test.mt64_data) |value| {
assert(value == rng.get());
}
}
-test "testMT19937_32" {
+test "rand.MT19937_32" {
var rng = MT19937_32.init(rand_test.mt32_seed);
for (rand_test.mt32_data) |value| {
assert(value == rng.get());
}
}
+
+test "rand.Rand.range" {
+ var r = Rand.init(42);
+ testRange(&r, -4, 3);
+ testRange(&r, -4, -1);
+ testRange(&r, 10, 14);
+}
+
+fn testRange(r: &Rand, start: i32, end: i32) {
+ const count = usize(end - start);
+ var values_buffer = []bool{false} ** 20;
+ const values = values_buffer[0..count];
+ var i: usize = 0;
+ while (i < count) {
+ const value = r.range(i32, start, end);
+ const index = usize(value - start);
+ if (!values[index]) {
+ i += 1;
+ values[index] = true;
+ }
+ }
+}