master
1//! Implements [ZIGNOR][1] (Jurgen A. Doornik, 2005, Nuffield College, Oxford).
2//!
3//! [1]: https://www.doornik.com/research/ziggurat.pdf
4//!
//! The Rust `rand` crate was used as a reference implementation.
6//!
//! NOTE: The approach described at https://sbarral.github.io/etf looks promising,
//! but its reference code is hard to follow.
9
10const std = @import("../std.zig");
11const builtin = @import("builtin");
12const math = std.math;
13const Random = std.Random;
14
/// Draws one sample from the distribution described by `tables` using the
/// ZIGNOR ziggurat rejection scheme.
///
/// Every attempt draws a single `u64`. Its low byte selects a strip (0..255)
/// and its top 52 bits become the mantissa of the uniform `u`. An attempt that
/// misses the strip's inner rectangle is handed to `tables.zero_case`
/// (strip 0) or is accepted/rejected against `tables.pdf` using one extra
/// uniform draw.
pub fn next_f64(random: Random, comptime tables: ZigTable) f64 {
    while (true) {
        const bits = random.int(u64);
        // The low byte picks the strip. Bits 8..11 are unused and the top 52
        // bits feed the mantissa below, so no second draw is needed for `u`.
        const strip: usize = @as(u8, @truncate(bits));
        const mantissa = bits >> 12;

        // Assemble `u` straight from an IEEE-754 bit pattern:
        // - symmetric: exponent 0x400 gives [2, 4); subtracting 3 yields [-1, 1).
        // - one-sided: exponent 0x3ff gives [1, 2); subtracting (1 - eps/2)
        //   yields (0, 1).
        const u: f64 = if (tables.is_symmetric)
            @as(f64, @bitCast((0x400 << 52) | mantissa)) - 3.0
        else
            @as(f64, @bitCast((0x3ff << 52) | mantissa)) - (1.0 - math.floatEps(f64) / 2.0);

        const x = u * tables.x[strip];
        const magnitude = if (tables.is_symmetric) @abs(x) else x;

        // Fast path: same as |u| < x[strip + 1] / x[strip], i.e. the point lies
        // in the part of the strip that is entirely under the curve.
        if (magnitude < tables.x[strip + 1]) return x;

        // Strip 0 is the base strip; its overflow is sampled by the fallback.
        if (strip == 0) return tables.zero_case(random, u);

        // Wedge: a uniform height between f[strip] and f[strip + 1],
        // accepted when it lies under the density (Doornik's
        // f1 + DRanU() * (f0 - f1) < 1 test).
        const height = tables.f[strip + 1] + (tables.f[strip] - tables.f[strip + 1]) * random.float(f64);
        if (height < tables.pdf(x)) return x;
    }
}
52
/// Precomputed ziggurat description of one distribution: 256 strips plus the
/// callbacks `next_f64` needs for its slow paths.
pub const ZigTable = struct {
    /// Right edge of the base strip (`x[1]`); fallback samples from
    /// `zero_case` come from beyond this point.
    r: f64,
    /// Strip boundaries. As built by `ZigTableGen`: `x[0]` is the width of the
    /// base strip's rectangle (`v / f(r)`), `x[1] == r`, and `x[256] == 0`.
    x: [257]f64,
    /// Density at each boundary; `ZigTableGen` sets `f[i] = pdf(x[i])`.
    f: [257]f64,

    /// Probability density function used for the wedge accept/reject test.
    /// It may be unnormalized (for example `norm_f`).
    pdf: fn (f64) f64,
    /// Whether the distribution is symmetric about zero. When set, `next_f64`
    /// draws `u` from [-1, 1) so samples can take either sign.
    is_symmetric: bool,
    /// Fallback sampler for strip 0 when the fast path fails. It receives the
    /// uniform `u` that was drawn, so symmetric distributions can pick a sign.
    zero_case: fn (Random, f64) f64,
};
65
/// Builds the 256-strip ziggurat tables for a density (the `zigNorInit` step
/// of Doornik's ZIGNOR). The result holds function bodies, so this must be
/// evaluated at comptime; the callers in this file also raise the eval branch
/// quota.
///
/// Parameters:
/// - `is_symmetric`: copied into the table; selects the sign-mirrored sampler.
/// - `r`: right edge of the base strip, where the tail begins.
/// - `v`: area assigned to every strip.
/// - `f`: density, assumed decreasing on [0, inf).
/// - `f_inv`: inverse of `f` on [0, inf).
/// - `zero_case`: fallback sampler for strip 0.
pub fn ZigTableGen(
    comptime is_symmetric: bool,
    comptime r: f64,
    comptime v: f64,
    comptime f: fn (f64) f64,
    comptime f_inv: fn (f64) f64,
    comptime zero_case: fn (Random, f64) f64,
) ZigTable {
    var edges: [257]f64 = undefined;
    // edges[0] is the width of a rectangle of height f(r) and area v (the base
    // strip). Every later boundary is chosen so that
    // edges[k - 1] * (f(edges[k]) - f(edges[k - 1])) == v.
    edges[0] = v / f(r);
    edges[1] = r;
    for (2..256) |k| {
        const prev = edges[k - 1];
        edges[k] = f_inv(v / prev + f(prev));
    }
    // The topmost strip closes at x = 0.
    edges[256] = 0;

    var heights: [257]f64 = undefined;
    for (&heights, edges) |*height, edge| height.* = f(edge);

    return .{
        .r = r,
        .x = edges,
        .f = heights,
        .pdf = f,
        .is_symmetric = is_symmetric,
        .zero_case = zero_case,
    };
}
97
/// Ziggurat tables for the standard normal distribution N(0, 1).
pub const NormDist: ZigTable = gen: {
    // Table generation iterates over every strip at comptime, so the default
    // eval branch quota is raised.
    @setEvalBranchQuota(30000);
    break :gen ZigTableGen(true, norm_r, norm_v, norm_f, norm_f_inv, norm_zero_case);
};

/// Right edge of the base strip (R) for the 256-strip normal ziggurat.
pub const norm_r = 3.6541528853610088;
/// Area of each strip (V) for the 256-strip normal ziggurat.
pub const norm_v = 0.00492867323399;
106
/// Unnormalized standard normal density, exp(-x^2 / 2).
pub fn norm_f(x: f64) f64 {
    const half_square = x * x / 2.0;
    return @exp(-half_square);
}
/// Inverse of `norm_f` on [0, inf): sqrt(-2 ln y) for y in (0, 1].
pub fn norm_f_inv(y: f64) f64 {
    const log_y = @log(y);
    return @sqrt(-2.0 * log_y);
}
/// Samples the normal tail beyond `norm_r` with Marsaglia's tail method.
/// It proposes `a = ln(U1) / norm_r` (never positive) and accepts once
/// `-2 ln(U2)` is no longer below `a^2`. The magnitude of the result exceeds
/// `norm_r`; a negative `u` yields a negative sample, otherwise positive.
pub fn norm_zero_case(random: Random, u: f64) f64 {
    while (true) {
        const a = @log(random.float(f64)) / norm_r;
        const b = @log(random.float(f64));
        // Reject and redraw while the proposal falls outside the tail density.
        if (-2.0 * b < a * a) continue;
        return if (u < 0) a - norm_r else norm_r - a;
    }
}
128
test "normal dist smoke test" {
    // Hardcode 0 as the seed because it's possible a seed exists that fails
    // this test.
    var prng = Random.DefaultPrng.init(0);
    const random = prng.random();

    const sample_count = 1000;
    var sum: f64 = 0;
    var sum_sq: f64 = 0;
    for (0..sample_count) |_| {
        const x = random.floatNorm(f64);
        try std.testing.expect(math.isFinite(x));
        sum += x;
        sum_sq += x * x;
    }

    // Loose bounds (about 6 standard errors for 1000 samples): these only
    // catch a grossly broken sampler, not ordinary sampling noise.
    const mean = sum / @as(f64, sample_count);
    const variance = sum_sq / @as(f64, sample_count) - mean * mean;
    try std.testing.expect(@abs(mean) < 0.2);
    try std.testing.expect(@abs(variance - 1.0) < 0.3);
}
140
/// Ziggurat tables for the exponential distribution Exp(1).
pub const ExpDist: ZigTable = gen: {
    // Table generation iterates over every strip at comptime, so the default
    // eval branch quota is raised.
    @setEvalBranchQuota(30000);
    break :gen ZigTableGen(false, exp_r, exp_v, exp_f, exp_f_inv, exp_zero_case);
};

/// Right edge of the base strip (R) for the 256-strip exponential ziggurat.
pub const exp_r = 7.69711747013104972;
/// Area of each strip (V) for the 256-strip exponential ziggurat.
pub const exp_v = 0.0039496598225815571993;
149
/// Exponential density with rate 1, exp(-x).
pub fn exp_f(x: f64) f64 {
    const exponent = -x;
    return @exp(exponent);
}
/// Inverse of `exp_f`: -ln(y) for y in (0, 1].
pub fn exp_f_inv(y: f64) f64 {
    const log_y = @log(y);
    return -log_y;
}
/// Samples Exp(1) conditioned on exceeding `exp_r`. The exponential
/// distribution is memoryless, so this is `exp_r` plus a fresh Exp(1) variate,
/// drawn by inversion as `-ln(U)`. The uniform `u` is ignored because the
/// distribution is one-sided.
pub fn exp_zero_case(random: Random, _: f64) f64 {
    const excess = -@log(random.float(f64));
    return exp_r + excess;
}
159
test "exp dist smoke test" {
    // Fixed seed keeps the statistical checks below deterministic.
    var prng = Random.DefaultPrng.init(0);
    const random = prng.random();

    const sample_count = 1000;
    var sum: f64 = 0;
    for (0..sample_count) |_| {
        const x = random.floatExp(f64);
        // Exp(1) is supported on [0, inf): every sample must be finite and
        // non-negative.
        try std.testing.expect(math.isFinite(x) and x >= 0);
        sum += x;
    }

    // Loose bound (about 6 standard errors for 1000 samples) on the mean of 1.
    const mean = sum / @as(f64, sample_count);
    try std.testing.expect(@abs(mean - 1.0) < 0.2);
}
169
/// Checks the structural invariants `ZigTableGen` establishes for a table:
/// the expected end points and strictly monotone boundaries and densities.
fn expectWellFormedTable(comptime table: ZigTable) !void {
    const xs = table.x;
    const fs = table.f;
    try std.testing.expectEqual(table.r, xs[1]);
    try std.testing.expectEqual(@as(f64, 0), xs[256]);
    // Both densities in this file evaluate to exactly 1 at x = 0.
    try std.testing.expectEqual(@as(f64, 1), fs[256]);
    // Boundaries shrink strictly toward zero. The strict comparison also
    // rejects any NaN produced by `f_inv` during generation.
    for (xs[0..256], xs[1..]) |outer, inner| {
        try std.testing.expect(outer > inner);
    }
    // The density is decreasing, so its values at the boundaries strictly rise.
    for (fs[0..256], fs[1..]) |low, high| {
        try std.testing.expect(low < high);
    }
}

test {
    // Referencing the tables forces their comptime generation to be analyzed.
    _ = NormDist;
    _ = ExpDist;
}

test "ziggurat tables are well formed" {
    try expectWellFormedTable(NormDist);
    try expectWellFormedTable(ExpDist);
}