Commit c7cb5c31e5

Marc Tiehuis <marctiehuis@gmail.com>
2018-04-14 11:08:49
Add exp/norm distributed random float generation
1 parent caefaf7
Changed files (3)
std/rand/index.zig
@@ -19,6 +19,7 @@ const builtin = @import("builtin");
 const assert = std.debug.assert;
 const mem = std.mem;
 const math = std.math;
+const ziggurat = @import("ziggurat.zig");
 
 // When you need fast unbiased random numbers
 pub const DefaultPrng = Xoroshiro128;
@@ -109,15 +110,28 @@ pub const Random = struct {
         }
     }
 
-    /// Return a floating point value normally distributed in the range [0, 1].
+    /// Return a floating point value normally distributed with mean = 0, stddev = 1.
+    ///
+    /// To use different parameters, use: floatNorm(...) * desiredStddev + desiredMean.
     pub fn floatNorm(r: &Random, comptime T: type) T {
-        // TODO(tiehuis): See https://www.doornik.com/research/ziggurat.pdf
-        @compileError("floatNorm is unimplemented");
+        const value = ziggurat.next_f64(r, ziggurat.NormDist);
+        switch (T) {
+            f32 => return f32(value),
+            f64 => return value,
+            else => @compileError("unknown floating point type"),
+        }
     }
 
-    /// Return a exponentially distributed float between (0, @maxValue(f64))
+    /// Return an exponentially distributed float with a rate parameter of 1.
+    ///
+    /// To use a different rate parameter, use: floatExp(...) / desiredRate.
     pub fn floatExp(r: &Random, comptime T: type) T {
-        @compileError("floatExp is unimplemented");
+        const value = ziggurat.next_f64(r, ziggurat.ExpDist);
+        switch (T) {
+            f32 => return f32(value),
+            f64 => return value,
+            else => @compileError("unknown floating point type"),
+        }
     }
 
     /// Shuffle a slice into a random order.
std/rand/ziggurat.zig
@@ -0,0 +1,146 @@
+// Implements ZIGNOR [1].
+//
+// [1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to Generate Normal Random Samples*]
+// (https://www.doornik.com/research/ziggurat.pdf). Nuffield College, Oxford.
+//
+// rust/rand used as a reference;
+//
+// NOTE: This seems interesting but reference code is a bit hard to grok:
+// https://sbarral.github.io/etf.
+
+const std = @import("../index.zig");
+const math = std.math;
+const Random = std.rand.Random;
+
+pub fn next_f64(random: &Random, comptime tables: &const ZigTable) f64 {
+    while (true) {
+        // We manually construct a float from parts as we can avoid an extra random lookup here by
+        // using the unused exponent for the lookup table entry.
+        const bits = random.scalar(u64);
+        const i = usize(bits & 0xff);
+
+        const u = blk: {
+            if (tables.is_symmetric) {
+                // Generate a value in the range [2, 4) and scale into [-1, 1)
+                const repr = ((0x3ff + 1) << 52) | (bits >> 12);
+                break :blk @bitCast(f64, repr) - 3.0;
+            } else {
+                // Generate a value in the range [1, 2) and scale into (0, 1)
+                const repr = (0x3ff << 52) | (bits >> 12);
+                break :blk @bitCast(f64, repr) - (1.0 - math.f64_epsilon / 2.0);
+            }
+        };
+
+        const x = u * tables.x[i];
+        const test_x = if (tables.is_symmetric) math.fabs(x) else x;
+
+        // equivalent to |u| < tables.x[i+1] / tables.x[i] (or u < tables.x[i+1] / tables.x[i])
+        if (test_x < tables.x[i + 1]) {
+            return x;
+        }
+
+        if (i == 0) {
+            return tables.zero_case(random, u);
+        }
+
+        // equivalent to f1 + DRanU() * (f0 - f1) < 1
+        if (tables.f[i + 1] + (tables.f[i] - tables.f[i + 1]) * random.float(f64) < tables.pdf(x)) {
+            return x;
+        }
+    }
+}
+
+pub const ZigTable = struct {
+    r: f64,
+    x: [257]f64,
+    f: [257]f64,
+
+    // probability density function used as a fallback
+    pdf: fn(f64) f64,
+    // whether the distribution is symmetric
+    is_symmetric: bool,
+    // fallback calculation in the case we are in the 0 block
+    zero_case: fn(&Random, f64) f64,
+};
+
+// zigNorInit
+fn ZigTableGen(comptime is_symmetric: bool, comptime r: f64, comptime v: f64, comptime f: fn(f64) f64,
+       comptime f_inv: fn(f64) f64, comptime zero_case: fn(&Random, f64) f64) ZigTable {
+    var tables: ZigTable = undefined;
+
+    tables.is_symmetric = is_symmetric;
+    tables.r = r;
+    tables.pdf = f;
+    tables.zero_case = zero_case;
+
+    tables.x[0] = v / f(r);
+    tables.x[1] = r;
+
+    for (tables.x[2..256]) |*entry, i| {
+        const last = tables.x[2 + i - 1];
+        *entry = f_inv(v / last + f(last));
+    }
+    tables.x[256] = 0;
+
+    for (tables.f[0..]) |*entry, i| {
+        *entry = f(tables.x[i]);
+    }
+
+    return tables;
+}
+
+// N(0, 1)
+pub const NormDist = blk: {
+    @setEvalBranchQuota(30000);
+    break :blk ZigTableGen(true, norm_r, norm_v, norm_f, norm_f_inv, norm_zero_case);
+};
+
+const norm_r = 3.6541528853610088;
+const norm_v = 0.00492867323399;
+
+fn norm_f(x: f64) f64 { return math.exp(-x * x / 2.0); }
+fn norm_f_inv(y: f64) f64 { return math.sqrt(-2.0 * math.ln(y)); }
+fn norm_zero_case(random: &Random, u: f64) f64 {
+    var x: f64 = 1;
+    var y: f64 = 0;
+
+    while (-2.0 * y < x * x) {
+        x = math.ln(random.float(f64)) / norm_r;
+        y = math.ln(random.float(f64));
+    }
+
+    if (u < 0) {
+        return x - norm_r;
+    } else {
+        return norm_r - x;
+    }
+}
+
+test "ziggurant normal dist sanity" {
+    var prng = std.rand.DefaultPrng.init(0);
+    var i: usize = 0;
+    while (i < 1000) : (i += 1) {
+        _ = prng.random.floatNorm(f64);
+    }
+}
+
+// Exp(1)
+pub const ExpDist = blk: {
+    @setEvalBranchQuota(30000);
+    break :blk ZigTableGen(false, exp_r, exp_v, exp_f, exp_f_inv, exp_zero_case);
+};
+
+const exp_r = 7.69711747013104972;
+const exp_v = 0.0039496598225815571993;
+
+fn exp_f(x: f64) f64 { return math.exp(-x); }
+fn exp_f_inv(y: f64) f64 { return -math.ln(y); }
+fn exp_zero_case(random: &Random, _: f64) f64 { return exp_r - math.ln(random.float(f64)); }
+
+test "ziggurant exp dist sanity" {
+    var prng = std.rand.DefaultPrng.init(0);
+    var i: usize = 0;
+    while (i < 1000) : (i += 1) {
+        _ = prng.random.floatExp(f64);
+    }
+}
CMakeLists.txt
@@ -515,6 +515,7 @@ set(ZIG_STD_FILES
     "os/windows/util.zig"
     "os/zen.zig"
     "rand/index.zig"
+    "rand/ziggurat.zig"
     "sort.zig"
     "special/bootstrap.zig"
     "special/bootstrap_lib.zig"