Commit 5bb8c03697
lib/std/rand/test.zig
@@ -445,3 +445,29 @@ test "CSPRNG" {
const c = random.int(u64);
try expect(a ^ b ^ c != 0);
}
+
+test "Random weightedIndex" {
+ // Make sure weightedIndex works for various integers and floats
+ inline for (.{ u64, i4, f32, f64 }) |T| {
+ var prng = DefaultPrng.init(0);
+ const random = prng.random();
+
+ var proportions = [_]T{ 2, 1, 1, 2 };
+ var counts = [_]f64{ 0, 0, 0, 0 };
+
+ const n_trials: u64 = 10_000;
+ var i: usize = 0;
+ while (i < n_trials) : (i += 1) {
+ const pick = random.weightedIndex(T, &proportions);
+ counts[pick] += 1;
+ }
+
+ // We expect the first and last counts to be roughly 2x the second and third
+ const approxEqRel = std.math.approxEqRel;
+ // Define "roughly" to be within 10%
+ const tolerance = 0.1;
+ try std.testing.expect(approxEqRel(f64, counts[0], counts[1] * 2, tolerance));
+ try std.testing.expect(approxEqRel(f64, counts[1], counts[2], tolerance));
+ try std.testing.expect(approxEqRel(f64, counts[2] * 2, counts[3], tolerance));
+ }
+}
lib/std/rand.zig
@@ -337,6 +337,42 @@ pub const Random = struct {
mem.swap(T, &buf[i], &buf[j]);
}
}
+
+ /// Randomly selects an index into `proportions`, where the likelihood of each
+ /// index is weighted by that proportion.
+ ///
+ /// This is useful for selecting an item from a slice where weights are not equal.
+ /// `T` must be a numeric type capable of holding the sum of `proportions`.
+ pub fn weightedIndex(r: std.rand.Random, comptime T: type, proportions: []T) usize {
+ // This implementation works by summing the proportions and picking a random
+ // point in [0, sum). We then loop over the proportions, accumulating
+ // until our accumulator is greater than the random point.
+
+ var sum: T = 0;
+ for (proportions) |v| {
+ sum += v;
+ }
+
+ const point = if (comptime std.meta.trait.isSignedInt(T))
+ r.intRangeLessThan(T, 0, sum)
+ else if (comptime std.meta.trait.isUnsignedInt(T))
+ r.uintLessThan(T, sum)
+ else if (comptime std.meta.trait.isFloat(T))
+ // take care that imprecision doesn't lead to a value slightly greater than sum
+ std.math.min(r.float(T) * sum, sum - std.math.epsilon(T))
+ else
+ @compileError("weightedIndex does not support proportions of type " ++ @typeName(T));
+
+ std.debug.assert(point < sum);
+
+ var accumulator: T = 0;
+ for (proportions) |p, index| {
+ accumulator += p;
+ if (point < accumulator) return index;
+ }
+
+ unreachable;
+ }
};
/// Convert a random integer 0 <= random_int <= maxValue(T),