Commit d68f39b541

Karl Seguin <karlseguin@users.noreply.github.com>
2023-10-07 05:49:21
std.unicode.utf8ValidateSlice: optimize implementation (#17329)
Originally inspired by Go's `utf8.Valid` function. Includes some test cases from Go's test suite. Further optimized to be faster in all tested cases (short/long ascii/UTF8), in all release modes. Takes advantage of SIMD for the ASCII fast path.
1 parent 5a4a587
Changed files (1)
lib
lib/std/unicode.zig
@@ -196,22 +196,115 @@ pub fn utf8CountCodepoints(s: []const u8) !usize {
     return len;
 }
 
-pub fn utf8ValidateSlice(s: []const u8) bool {
+/// Returns true if the input consists entirely of UTF-8 codepoints
+pub fn utf8ValidateSlice(input: []const u8) bool {
+    var remaining = input;
+
+    const V_len = comptime std.simd.suggestVectorSize(usize) orelse 1;
+    const V = @Vector(V_len, usize);
+    const u8s_in_vector = @sizeOf(usize) * V_len;
+
+    // Fast path. Check for and skip ASCII characters at the start of the input.
+    while (remaining.len >= u8s_in_vector) {
+        const chunk: V = @bitCast(remaining[0..u8s_in_vector].*);
+        const swapped = mem.littleToNative(V, chunk);
+        const reduced = @reduce(.Or, swapped);
+        const mask: usize = @bitCast([1]u8{0x80} ** @sizeOf(usize));
+        if (reduced & mask != 0) {
+            // Found a non ASCII byte
+            break;
+        }
+        remaining = remaining[u8s_in_vector..];
+    }
+
+    // default lowest and highest continuation byte
+    const lo_cb = 0b10000000;
+    const hi_cb = 0b10111111;
+
+    const min_non_ascii_codepoint = 0x80;
+
+    // The first nibble is used to identify the continuation byte range to
+    // accept. The second nibble is the size.
+    const xx = 0xF1; // invalid: size 1
+    const as = 0xF0; // ASCII: size 1
+    const s1 = 0x02; // accept 0, size 2
+    const s2 = 0x13; // accept 1, size 3
+    const s3 = 0x03; // accept 0, size 3
+    const s4 = 0x23; // accept 2, size 3
+    const s5 = 0x34; // accept 3, size 4
+    const s6 = 0x04; // accept 0, size 4
+    const s7 = 0x44; // accept 4, size 4
+
+    // Information about the first byte in a UTF-8 sequence.
+    const first = comptime ([_]u8{as} ** 128) ++ ([_]u8{xx} ** 64) ++ [_]u8{
+        xx, xx, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1,
+        s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1, s1,
+        s2, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s3, s4, s3, s3,
+        s5, s6, s6, s6, s7, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx, xx,
+    };
+
+    var n = remaining.len;
     var i: usize = 0;
-    while (i < s.len) {
-        if (utf8ByteSequenceLength(s[i])) |cp_len| {
-            if (i + cp_len > s.len) {
-                return false;
-            }
+    while (i < n) {
+        const first_byte = remaining[i];
+        if (first_byte < min_non_ascii_codepoint) {
+            i += 1;
+            continue;
+        }
 
-            if (std.meta.isError(utf8Decode(s[i .. i + cp_len]))) {
-                return false;
-            }
-            i += cp_len;
-        } else |_| {
+        const info = first[first_byte];
+        if (info == xx) {
+            return false; // Illegal starter byte.
+        }
+
+        const size = info & 7;
+        if (i + size > n) {
+            return false; // Short or invalid.
+        }
+
+        // Figure out the acceptable low and high continuation bytes, starting
+        // with our defaults.
+        var accept_lo: u8 = lo_cb;
+        var accept_hi: u8 = hi_cb;
+
+        switch (info >> 4) {
+            0 => {},
+            1 => accept_lo = 0xA0,
+            2 => accept_hi = 0x9F,
+            3 => accept_lo = 0x90,
+            4 => accept_hi = 0x8F,
+            else => unreachable,
+        }
+
+        const c1 = remaining[i + 1];
+        if (c1 < accept_lo or accept_hi < c1) {
             return false;
         }
+
+        switch (size) {
+            2 => i += 2,
+            3 => {
+                const c2 = remaining[i + 2];
+                if (c2 < lo_cb or hi_cb < c2) {
+                    return false;
+                }
+                i += 3;
+            },
+            4 => {
+                const c2 = remaining[i + 2];
+                if (c2 < lo_cb or hi_cb < c2) {
+                    return false;
+                }
+                const c3 = remaining[i + 3];
+                if (c3 < lo_cb or hi_cb < c3) {
+                    return false;
+                }
+                i += 4;
+            },
+            else => unreachable,
+        }
     }
+
     return true;
 }
 
@@ -502,15 +595,44 @@ fn testUtf8ViewOk() !void {
     try testing.expect(it2.nextCodepoint() == null);
 }
 
-test "bad utf8 slice" {
-    try comptime testBadUtf8Slice();
-    try testBadUtf8Slice();
+test "validate slice" {
+    try comptime testValidateSlice();
+    try testValidateSlice();
+
+    // We skip a variable (based on recommended vector size) chunks of
+    // ASCII characters. Let's make sure we're chunking correctly.
+    const str = [_]u8{'a'} ** 550 ++ "\xc0";
+    for (0..str.len - 3) |i| {
+        try testing.expect(!utf8ValidateSlice(str[i..]));
+    }
 }
-fn testBadUtf8Slice() !void {
+fn testValidateSlice() !void {
     try testing.expect(utf8ValidateSlice("abc"));
+    try testing.expect(utf8ValidateSlice("abc\xdf\xbf"));
+    try testing.expect(utf8ValidateSlice(""));
+    try testing.expect(utf8ValidateSlice("a"));
+    try testing.expect(utf8ValidateSlice("abc"));
+    try testing.expect(utf8ValidateSlice("Ж"));
+    try testing.expect(utf8ValidateSlice("ЖЖ"));
+    try testing.expect(utf8ValidateSlice("брэд-ЛГТМ"));
+    try testing.expect(utf8ValidateSlice("☺☻☹"));
+    try testing.expect(utf8ValidateSlice("a\u{fffdb}"));
+    try testing.expect(utf8ValidateSlice("\xf4\x8f\xbf\xbf"));
+    try testing.expect(utf8ValidateSlice("abc\xdf\xbf"));
+
     try testing.expect(!utf8ValidateSlice("abc\xc0"));
     try testing.expect(!utf8ValidateSlice("abc\xc0abc"));
-    try testing.expect(utf8ValidateSlice("abc\xdf\xbf"));
+    try testing.expect(!utf8ValidateSlice("aa\xe2"));
+    try testing.expect(!utf8ValidateSlice("\x42\xfa"));
+    try testing.expect(!utf8ValidateSlice("\x42\xfa\x43"));
+    try testing.expect(!utf8ValidateSlice("abc\xc0"));
+    try testing.expect(!utf8ValidateSlice("abc\xc0abc"));
+    try testing.expect(!utf8ValidateSlice("\xf4\x90\x80\x80"));
+    try testing.expect(!utf8ValidateSlice("\xf7\xbf\xbf\xbf"));
+    try testing.expect(!utf8ValidateSlice("\xfb\xbf\xbf\xbf\xbf"));
+    try testing.expect(!utf8ValidateSlice("\xc0\x80"));
+    try testing.expect(!utf8ValidateSlice("\xed\xa0\x80"));
+    try testing.expect(!utf8ValidateSlice("\xed\xbf\xbf"));
 }
 
 test "valid utf8" {