Commit 550888e2ac

Mahdi Rakhshandehroo <24380301+mrakh@users.noreply.github.com>
2021-12-28 01:13:15
std: improve random float generation
1 parent e0a514d
Changed files (2)
lib
lib/std/rand/Dilbert.zig
@@ -0,0 +1,52 @@
+//! Dilbert PRNG
+//! Do not use this PRNG! It is meant to be predictable, for the purposes of test reproducibility and coverage. 
+//! Its output is just a repeat of a user-specified byte pattern.
+//! Name is a reference to this comic: https://dilbert.com/strip/2001-10-25
+
+const std = @import("std");
+const Random = std.rand.Random;
+const math = std.math;
+const Dilbert = @This();
+
+pattern: []const u8 = undefined,
+curr_idx: usize = 0,
+
+pub fn init(pattern: []const u8) !Dilbert {
+    if (pattern.len == 0)
+        return error.EmptyPattern;
+    var self = Dilbert{};
+    self.pattern = pattern;
+    self.curr_idx = 0;
+    return self;
+}
+
+pub fn random(self: *Dilbert) Random {
+    return Random.init(self, fill);
+}
+
+pub fn fill(self: *Dilbert, buf: []u8) void {
+    for (buf) |*byte| {
+        byte.* = self.pattern[self.curr_idx];
+        self.curr_idx = (self.curr_idx + 1) % self.pattern.len;
+    }
+}
+
+test "Dilbert fill" {
+    var r = try Dilbert.init("9nine");
+
+    const seq = [_]u64{
+        0x396E696E65396E69,
+        0x6E65396E696E6539,
+        0x6E696E65396E696E,
+        0x65396E696E65396E,
+        0x696E65396E696E65,
+    };
+
+    for (seq) |s| {
+        var buf0: [8]u8 = undefined;
+        var buf1: [8]u8 = undefined;
+        std.mem.writeIntBig(u64, &buf0, s);
+        r.fill(&buf1);
+        try std.testing.expect(std.mem.eql(u8, buf0[0..], buf1[0..]));
+    }
+}
lib/std/rand.zig
@@ -16,6 +16,8 @@ const math = std.math;
 const ziggurat = @import("rand/ziggurat.zig");
 const maxInt = std.math.maxInt;
 
+const Dilbert = @import("rand/Dilbert.zig");
+
 /// Fast unbiased random numbers.
 pub const DefaultPrng = Xoshiro256;
 
@@ -249,18 +251,51 @@ pub const Random = struct {
 
     /// Return a floating point value evenly distributed in the range [0, 1).
     pub fn float(r: Random, comptime T: type) T {
-        // Generate a uniform value between [1, 2) and scale down to [0, 1).
-        // Note: The lowest mantissa bit is always set to 0 so we only use half the available range.
+        // Generate a uniformly random value between for the mantissa.
+        // Then generate an exponentially biased random value for the exponent.
+        // Over the previous method, this has the advantage of being able to
+        // represent every possible value in the available range.
         switch (T) {
             f32 => {
-                const s = r.int(u32);
-                const repr = (0x7f << 23) | (s >> 9);
-                return @bitCast(f32, repr) - 1.0;
+                // Use 23 random bits for the mantissa, and the rest for the exponent.
+                // If all 41 bits are zero, generate additional random bits, until a
+                // set bit is found, or 126 bits have been generated.
+                const rand = r.int(u64);
+                var rand_lz = @clz(u64, rand | 0x7FFFFF);
+                if (rand_lz == 41) {
+                    rand_lz += @clz(u64, r.int(u64));
+                    if (rand_lz == 41 + 64) {
+                        // It is astronomically unlikely to reach this point.
+                        rand_lz += @clz(u32, r.int(u32) | 0x7FF);
+                    }
+                }
+                const mantissa = @truncate(u23, rand);
+                const exponent = @as(u32, 126 - rand_lz) << 23;
+                return @bitCast(f32, exponent | mantissa);
             },
             f64 => {
-                const s = r.int(u64);
-                const repr = (0x3ff << 52) | (s >> 12);
-                return @bitCast(f64, repr) - 1.0;
+                // Use 52 random bits for the mantissa, and the rest for the exponent.
+                // If all 12 bits are zero, generate additional random bits, until a
+                // set bit is found, or 1022 bits have been generated.
+                const rand = r.int(u64);
+                var rand_lz: u64 = @clz(u64, rand | 0xFFFFFFFFFFFFF);
+                if (rand_lz == 12) {
+                    while (true) {
+                        // It is astronomically unlikely for this loop to execute more than once.
+                        const addl_rand_lz = @clz(u64, r.int(u64));
+                        rand_lz += addl_rand_lz;
+                        if (addl_rand_lz != 64) {
+                            break;
+                        }
+                        if (rand_lz >= 1022) {
+                            rand_lz = 1022;
+                            break;
+                        }
+                    }
+                }
+                const mantissa = rand & 0xFFFFFFFFFFFFF;
+                const exponent = (1022 - rand_lz) << 52;
+                return @bitCast(f64, exponent | mantissa);
             },
             else => @compileError("unknown floating point type"),
         }
@@ -573,7 +608,7 @@ test "splitmix64 sequence" {
 }
 
 // Actual Random helper function tests, pcg engine is assumed correct.
-test "Random float" {
+test "Random float correctness" {
     var prng = DefaultPrng.init(0);
     const random = prng.random();
 
@@ -589,6 +624,81 @@ test "Random float" {
     }
 }
 
+// Check the "astronomically unlikely" code paths.
+test "Random float coverage" {
+    var prng = try Dilbert.init(&[_]u8{0});
+    const random = prng.random();
+
+    const rand_f64 = random.float(f64);
+    const rand_f32 = random.float(f32);
+
+    try expect(rand_f32 == 0.0);
+    try expect(rand_f64 == 0.0);
+}
+
+test "Random float chi-square goodness of fit" {
+    const num_numbers = 100000;
+    const num_buckets = 1000;
+
+    var f32_hist = std.AutoHashMap(u32, u32).init(std.testing.allocator);
+    defer f32_hist.deinit();
+    var f64_hist = std.AutoHashMap(u64, u32).init(std.testing.allocator);
+    defer f64_hist.deinit();
+
+    var prng = DefaultPrng.init(0);
+    const random = prng.random();
+
+    var i: usize = 0;
+    while (i < num_numbers) : (i += 1) {
+        const rand_f32 = random.float(f32);
+        const rand_f64 = random.float(f64);
+        var f32_put = try f32_hist.getOrPut(@floatToInt(u32, rand_f32 * @intToFloat(f32, num_buckets)));
+        if (f32_put.found_existing) {
+            f32_put.value_ptr.* += 1;
+        } else {
+            f32_put.value_ptr.* = 0;
+        }
+        var f64_put = try f64_hist.getOrPut(@floatToInt(u32, rand_f64 * @intToFloat(f64, num_buckets)));
+        if (f64_put.found_existing) {
+            f64_put.value_ptr.* += 1;
+        } else {
+            f64_put.value_ptr.* = 0;
+        }
+    }
+
+    var f32_total_variance: f64 = 0;
+    var f64_total_variance: f64 = 0;
+
+    {
+        var j: u32 = 0;
+        while (j < num_buckets) : (j += 1) {
+            const count = @intToFloat(f64, (if (f32_hist.get(j)) |v| v else 0));
+            const expected = @intToFloat(f64, num_numbers) / @intToFloat(f64, num_buckets);
+            const delta = count - expected;
+            const variance = (delta * delta) / expected;
+            f32_total_variance += variance;
+        }
+    }
+
+    {
+        var j: u64 = 0;
+        while (j < num_buckets) : (j += 1) {
+            const count = @intToFloat(f64, (if (f64_hist.get(j)) |v| v else 0));
+            const expected = @intToFloat(f64, num_numbers) / @intToFloat(f64, num_buckets);
+            const delta = count - expected;
+            const variance = (delta * delta) / expected;
+            f64_total_variance += variance;
+        }
+    }
+
+    // Corresponds to a p-value > 0.05.
+    // Critical value is calculated by opening a Python interpreter and running:
+    // scipy.stats.chi2.isf(0.05, num_buckets - 1)
+    const critical_value = 1073.6426506574246;
+    try expect(f32_total_variance < critical_value);
+    try expect(f64_total_variance < critical_value);
+}
+
 test "Random shuffle" {
     var prng = DefaultPrng.init(0);
     const random = prng.random();