Commit 5bb8c03697

Justin Whear <justin.whear@gmail.com>
2022-08-28 13:19:51
std.random: add weightedIndex function
`weightedIndex` picks from a selection of weighted indices.
1 parent 0f27836
Changed files (2)
lib
lib/std/rand/test.zig
@@ -445,3 +445,29 @@ test "CSPRNG" {
     const c = random.int(u64);
     try expect(a ^ b ^ c != 0);
 }
+
+test "Random weightedIndex" {
+    // Make sure weightedIndex works for various integers and floats
+    inline for (.{ u64, i4, f32, f64 }) |T| {
+        var prng = DefaultPrng.init(0);
+        const random = prng.random();
+
+        var proportions = [_]T{ 2, 1, 1, 2 };
+        var counts = [_]f64{ 0, 0, 0, 0 };
+
+        const n_trials: u64 = 10_000;
+        var i: usize = 0;
+        while (i < n_trials) : (i += 1) {
+            const pick = random.weightedIndex(T, &proportions);
+            counts[pick] += 1;
+        }
+
+        // We expect the first and last counts to be roughly 2x the second and third
+        const approxEqRel = std.math.approxEqRel;
+        // Define "roughly" to be within 10%
+        const tolerance = 0.1;
+        try std.testing.expect(approxEqRel(f64, counts[0], counts[1] * 2, tolerance));
+        try std.testing.expect(approxEqRel(f64, counts[1], counts[2], tolerance));
+        try std.testing.expect(approxEqRel(f64, counts[2] * 2, counts[3], tolerance));
+    }
+}
lib/std/rand.zig
@@ -337,6 +337,42 @@ pub const Random = struct {
             mem.swap(T, &buf[i], &buf[j]);
         }
     }
+
+    /// Randomly selects an index into `proportions`, where the likelihood of each
+    /// index is weighted by that proportion.
+    ///
+    /// This is useful for selecting an item from a slice where weights are not equal.
+    /// `T` must be a numeric type capable of holding the sum of `proportions`.
+    pub fn weightedIndex(r: std.rand.Random, comptime T: type, proportions: []T) usize {
+        // This implementation works by summing the proportions and picking a random
+        //  point in [0, sum).  We then loop over the proportions, accumulating
+        //  until our accumulator is greater than the random point.
+
+        var sum: T = 0;
+        for (proportions) |v| {
+            sum += v;
+        }
+
+        const point = if (comptime std.meta.trait.isSignedInt(T))
+            r.intRangeLessThan(T, 0, sum)
+        else if (comptime std.meta.trait.isUnsignedInt(T))
+            r.uintLessThan(T, sum)
+        else if (comptime std.meta.trait.isFloat(T))
+            // take care that imprecision doesn't lead to a value slightly greater than sum
+            std.math.min(r.float(T) * sum, sum - std.math.epsilon(T))
+        else
+            @compileError("weightedIndex does not support proportions of type " ++ @typeName(T));
+
+        std.debug.assert(point < sum);
+
+        var accumulator: T = 0;
+        for (proportions) |p, index| {
+            accumulator += p;
+            if (point < accumulator) return index;
+        }
+
+        unreachable;
+    }
 };
 
 /// Convert a random integer 0 <= random_int <= maxValue(T),