Commit 163ebe044b

Henry John Kupty <hkupty@users.noreply.github.com>
2025-10-07 18:32:13
std.mem.countScalar: rework to benefit from simd (#25477)
`findScalarPos` might do repetitive work, even if using simd. For example, when searching the string `/abcde/fghijk/lm` for the character `/`, a 16-byte wide search would yield `1000001000000100` but would only count the first `1` and re-search the remaining of the string. When testing locally, the difference was quite significative: ``` count scalar 5737 iterations 522.83us per iterations 0 bytes per iteration worst: 2370us median: 512us stddev: 107.64us count v2 38333 iterations 78.03us per iterations 0 bytes per iteration worst: 713us median: 76us stddev: 10.62us count scalar v2 99565 iterations 29.80us per iterations 0 bytes per iteration worst: 41us median: 29us stddev: 1.04us ``` Note that `count v2` is a simpler string search, similar to the remaining version of the simd approach: ``` pub fn countV2(comptime T: type, haystack: []const T, needle: T) usize { const n = haystack.len; if (n < 1) return 0; var count: usize = 0; for (haystack[0..n]) |item| { count += @intFromBool(item == needle); } return count; } ``` Which implies the compiler yields some optimized code for a simpler loop that is more performant than the `findScalarPos`-based approach, hence the usage of iterative approach for the remaining of the haystack. Co-authored-by: StAlKeR7779 <stalkek7779@yandex.ru>
1 parent 9760068
Changed files (1)
lib
lib/std/mem.zig
@@ -1706,12 +1706,26 @@ test count {
 
 /// Returns the number of needles inside the haystack
 pub fn countScalar(comptime T: type, haystack: []const T, needle: T) usize {
+    const n = haystack.len;
     var i: usize = 0;
     var found: usize = 0;
 
-    while (findScalarPos(T, haystack, i, needle)) |idx| {
-        i = idx + 1;
-        found += 1;
+    if (use_vectors_for_comparison and
+        (@typeInfo(T) == .int or @typeInfo(T) == .float) and std.math.isPowerOfTwo(@bitSizeOf(T)))
+    {
+        if (std.simd.suggestVectorLength(T)) |block_size| {
+            const Block = @Vector(block_size, T);
+
+            const letter_mask: Block = @splat(needle);
+            while (n - i >= block_size) : (i += block_size) {
+                const haystack_block: Block = haystack[i..][0..block_size].*;
+                found += std.simd.countTrues(letter_mask == haystack_block);
+            }
+        }
+    }
+
+    for (haystack[i..n]) |item| {
+        found += @intFromBool(item == needle);
     }
 
     return found;