Commit a155e35850

Ryan Liptak <squeek502@hotmail.com>
2023-08-15 15:11:59
std.json: Fix decoding of UTF-16 surrogate pairs (#16830)
* std.unicode: Add more UTF-16 decoding functions This mostly makes parts of Utf16LeIterator reusable * std.json: Fix decoding of UTF-16 surrogate pairs Before this commit, there were 524,288 codepoints that would get decoded improperly. After this commit, there are 0. Fixes #16828
1 parent f7b82ed
Changed files (3)
lib/std/json/scanner.zig
@@ -414,7 +414,7 @@ pub const Scanner = struct {
     string_is_object_key: bool = false,
     stack: BitStack,
     value_start: usize = undefined,
-    unicode_code_point: u21 = undefined,
+    utf16_code_units: [2]u16 = undefined,
 
     input: []const u8 = "",
     cursor: usize = 0,
@@ -1083,13 +1083,13 @@ pub const Scanner = struct {
                     const c = try self.expectByte();
                     switch (c) {
                         '0'...'9' => {
-                            self.unicode_code_point = @as(u21, c - '0') << 12;
+                            self.utf16_code_units[0] = @as(u16, c - '0') << 12;
                         },
                         'A'...'F' => {
-                            self.unicode_code_point = @as(u21, c - 'A' + 10) << 12;
+                            self.utf16_code_units[0] = @as(u16, c - 'A' + 10) << 12;
                         },
                         'a'...'f' => {
-                            self.unicode_code_point = @as(u21, c - 'a' + 10) << 12;
+                            self.utf16_code_units[0] = @as(u16, c - 'a' + 10) << 12;
                         },
                         else => return error.SyntaxError,
                     }
@@ -1101,13 +1101,13 @@ pub const Scanner = struct {
                     const c = try self.expectByte();
                     switch (c) {
                         '0'...'9' => {
-                            self.unicode_code_point |= @as(u21, c - '0') << 8;
+                            self.utf16_code_units[0] |= @as(u16, c - '0') << 8;
                         },
                         'A'...'F' => {
-                            self.unicode_code_point |= @as(u21, c - 'A' + 10) << 8;
+                            self.utf16_code_units[0] |= @as(u16, c - 'A' + 10) << 8;
                         },
                         'a'...'f' => {
-                            self.unicode_code_point |= @as(u21, c - 'a' + 10) << 8;
+                            self.utf16_code_units[0] |= @as(u16, c - 'a' + 10) << 8;
                         },
                         else => return error.SyntaxError,
                     }
@@ -1119,13 +1119,13 @@ pub const Scanner = struct {
                     const c = try self.expectByte();
                     switch (c) {
                         '0'...'9' => {
-                            self.unicode_code_point |= @as(u21, c - '0') << 4;
+                            self.utf16_code_units[0] |= @as(u16, c - '0') << 4;
                         },
                         'A'...'F' => {
-                            self.unicode_code_point |= @as(u21, c - 'A' + 10) << 4;
+                            self.utf16_code_units[0] |= @as(u16, c - 'A' + 10) << 4;
                         },
                         'a'...'f' => {
-                            self.unicode_code_point |= @as(u21, c - 'a' + 10) << 4;
+                            self.utf16_code_units[0] |= @as(u16, c - 'a' + 10) << 4;
                         },
                         else => return error.SyntaxError,
                     }
@@ -1137,31 +1137,26 @@ pub const Scanner = struct {
                     const c = try self.expectByte();
                     switch (c) {
                         '0'...'9' => {
-                            self.unicode_code_point |= c - '0';
+                            self.utf16_code_units[0] |= c - '0';
                         },
                         'A'...'F' => {
-                            self.unicode_code_point |= c - 'A' + 10;
+                            self.utf16_code_units[0] |= c - 'A' + 10;
                         },
                         'a'...'f' => {
-                            self.unicode_code_point |= c - 'a' + 10;
+                            self.utf16_code_units[0] |= c - 'a' + 10;
                         },
                         else => return error.SyntaxError,
                     }
                     self.cursor += 1;
-                    switch (self.unicode_code_point) {
-                        0xD800...0xDBFF => {
-                            // High surrogate half.
-                            self.unicode_code_point = 0x10000 | (self.unicode_code_point << 10);
-                            self.state = .string_surrogate_half;
-                            continue :state_loop;
-                        },
-                        0xDC00...0xDFFF => return error.SyntaxError, // Unexpected low surrogate half.
-                        else => {
-                            // Code point from a single UTF-16 code unit.
-                            self.value_start = self.cursor;
-                            self.state = .string;
-                            return self.partialStringCodepoint();
-                        },
+                    if (std.unicode.utf16IsHighSurrogate(self.utf16_code_units[0])) {
+                        self.state = .string_surrogate_half;
+                        continue :state_loop;
+                    } else if (std.unicode.utf16IsLowSurrogate(self.utf16_code_units[0])) {
+                        return error.SyntaxError; // Unexpected low surrogate half.
+                    } else {
+                        self.value_start = self.cursor;
+                        self.state = .string;
+                        return partialStringCodepoint(self.utf16_code_units[0]);
                     }
                 },
                 .string_surrogate_half => {
@@ -1188,6 +1183,7 @@ pub const Scanner = struct {
                     switch (try self.expectByte()) {
                         'D', 'd' => {
                             self.cursor += 1;
+                            self.utf16_code_units[1] = 0xD << 12;
                             self.state = .string_surrogate_half_backslash_u_1;
                             continue :state_loop;
                         },
@@ -1199,13 +1195,13 @@ pub const Scanner = struct {
                     switch (c) {
                         'C'...'F' => {
                             self.cursor += 1;
-                            self.unicode_code_point |= @as(u21, c - 'C') << 8;
+                            self.utf16_code_units[1] |= @as(u16, c - 'A' + 10) << 8;
                             self.state = .string_surrogate_half_backslash_u_2;
                             continue :state_loop;
                         },
                         'c'...'f' => {
                             self.cursor += 1;
-                            self.unicode_code_point |= @as(u21, c - 'c') << 8;
+                            self.utf16_code_units[1] |= @as(u16, c - 'a' + 10) << 8;
                             self.state = .string_surrogate_half_backslash_u_2;
                             continue :state_loop;
                         },
@@ -1217,19 +1213,19 @@ pub const Scanner = struct {
                     switch (c) {
                         '0'...'9' => {
                             self.cursor += 1;
-                            self.unicode_code_point |= @as(u21, c - '0') << 4;
+                            self.utf16_code_units[1] |= @as(u16, c - '0') << 4;
                             self.state = .string_surrogate_half_backslash_u_3;
                             continue :state_loop;
                         },
                         'A'...'F' => {
                             self.cursor += 1;
-                            self.unicode_code_point |= @as(u21, c - 'A' + 10) << 4;
+                            self.utf16_code_units[1] |= @as(u16, c - 'A' + 10) << 4;
                             self.state = .string_surrogate_half_backslash_u_3;
                             continue :state_loop;
                         },
                         'a'...'f' => {
                             self.cursor += 1;
-                            self.unicode_code_point |= @as(u21, c - 'a' + 10) << 4;
+                            self.utf16_code_units[1] |= @as(u16, c - 'a' + 10) << 4;
                             self.state = .string_surrogate_half_backslash_u_3;
                             continue :state_loop;
                         },
@@ -1240,20 +1236,21 @@ pub const Scanner = struct {
                     const c = try self.expectByte();
                     switch (c) {
                         '0'...'9' => {
-                            self.unicode_code_point |= c - '0';
+                            self.utf16_code_units[1] |= c - '0';
                         },
                         'A'...'F' => {
-                            self.unicode_code_point |= c - 'A' + 10;
+                            self.utf16_code_units[1] |= c - 'A' + 10;
                         },
                         'a'...'f' => {
-                            self.unicode_code_point |= c - 'a' + 10;
+                            self.utf16_code_units[1] |= c - 'a' + 10;
                         },
                         else => return error.SyntaxError,
                     }
                     self.cursor += 1;
                     self.value_start = self.cursor;
                     self.state = .string;
-                    return self.partialStringCodepoint();
+                    const code_point = std.unicode.utf16DecodeSurrogatePair(&self.utf16_code_units) catch unreachable;
+                    return partialStringCodepoint(code_point);
                 },
 
                 .string_utf8_last_byte => {
@@ -1681,9 +1678,7 @@ pub const Scanner = struct {
         return Token{ .partial_number = slice };
     }
 
-    fn partialStringCodepoint(self: *@This()) Token {
-        const code_point = self.unicode_code_point;
-        self.unicode_code_point = undefined;
+    fn partialStringCodepoint(code_point: u21) Token {
         var buf: [4]u8 = undefined;
         switch (std.unicode.utf8Encode(code_point, &buf) catch unreachable) {
             1 => return Token{ .partial_string_escaped_1 = buf[0..1].* },
lib/std/json/scanner_test.zig
@@ -236,6 +236,7 @@ const string_test_cases = .{
     .{ "\\u000a", "\n" },
     .{ "𝄞", "\u{1D11E}" },
     .{ "\\uD834\\uDD1E", "\u{1D11E}" },
+    .{ "\\uD87F\\uDFFE", "\u{2FFFE}" },
     .{ "\\uff20", "@" },
 };
 
lib/std/unicode.zig
@@ -293,6 +293,58 @@ pub const Utf8Iterator = struct {
     }
 };
 
+pub fn utf16IsHighSurrogate(c: u16) bool {
+    return c & ~@as(u16, 0x03ff) == 0xd800;
+}
+
+pub fn utf16IsLowSurrogate(c: u16) bool {
+    return c & ~@as(u16, 0x03ff) == 0xdc00;
+}
+
+/// Returns how many code units the UTF-16 representation would require
+/// for the given codepoint.
+pub fn utf16CodepointSequenceLength(c: u21) !u2 {
+    if (c <= 0xFFFF) return 1;
+    if (c <= 0x10FFFF) return 2;
+    return error.CodepointTooLarge;
+}
+
+test utf16CodepointSequenceLength {
+    try testing.expectEqual(@as(u2, 1), try utf16CodepointSequenceLength('a'));
+    try testing.expectEqual(@as(u2, 1), try utf16CodepointSequenceLength(0xFFFF));
+    try testing.expectEqual(@as(u2, 2), try utf16CodepointSequenceLength(0x10000));
+    try testing.expectEqual(@as(u2, 2), try utf16CodepointSequenceLength(0x10FFFF));
+    try testing.expectError(error.CodepointTooLarge, utf16CodepointSequenceLength(0x110000));
+}
+
+/// Given the first code unit of a UTF-16 codepoint, returns a number 1-2
+/// indicating the total length of the codepoint in UTF-16 code units.
+/// If this code unit does not match the form of a UTF-16 start code unit, returns Utf16InvalidStartCodeUnit.
+pub fn utf16CodeUnitSequenceLength(first_code_unit: u16) !u2 {
+    if (utf16IsHighSurrogate(first_code_unit)) return 2;
+    if (utf16IsLowSurrogate(first_code_unit)) return error.Utf16InvalidStartCodeUnit;
+    return 1;
+}
+
+test utf16CodeUnitSequenceLength {
+    try testing.expectEqual(@as(u2, 1), try utf16CodeUnitSequenceLength('a'));
+    try testing.expectEqual(@as(u2, 1), try utf16CodeUnitSequenceLength(0xFFFF));
+    try testing.expectEqual(@as(u2, 2), try utf16CodeUnitSequenceLength(0xDBFF));
+    try testing.expectError(error.Utf16InvalidStartCodeUnit, utf16CodeUnitSequenceLength(0xDFFF));
+}
+
+/// Decodes the codepoint encoded in the given pair of UTF-16 code units.
+/// Asserts that `surrogate_pair.len >= 2` and that the first code unit is a high surrogate.
+/// If the second code unit is not a low surrogate, error.ExpectedSecondSurrogateHalf is returned.
+pub fn utf16DecodeSurrogatePair(surrogate_pair: []const u16) !u21 {
+    assert(surrogate_pair.len >= 2);
+    assert(utf16IsHighSurrogate(surrogate_pair[0]));
+    const high_half: u21 = surrogate_pair[0];
+    const low_half = surrogate_pair[1];
+    if (!utf16IsLowSurrogate(low_half)) return error.ExpectedSecondSurrogateHalf;
+    return 0x10000 + ((high_half & 0x03ff) << 10) | (low_half & 0x03ff);
+}
+
 pub const Utf16LeIterator = struct {
     bytes: []const u8,
     i: usize,
@@ -307,19 +359,20 @@ pub const Utf16LeIterator = struct {
     pub fn nextCodepoint(it: *Utf16LeIterator) !?u21 {
         assert(it.i <= it.bytes.len);
         if (it.i == it.bytes.len) return null;
-        const c0: u21 = mem.readIntLittle(u16, it.bytes[it.i..][0..2]);
+        var code_units: [2]u16 = undefined;
+        code_units[0] = mem.readIntLittle(u16, it.bytes[it.i..][0..2]);
         it.i += 2;
-        if (c0 & ~@as(u21, 0x03ff) == 0xd800) {
+        if (utf16IsHighSurrogate(code_units[0])) {
             // surrogate pair
             if (it.i >= it.bytes.len) return error.DanglingSurrogateHalf;
-            const c1: u21 = mem.readIntLittle(u16, it.bytes[it.i..][0..2]);
-            if (c1 & ~@as(u21, 0x03ff) != 0xdc00) return error.ExpectedSecondSurrogateHalf;
+            code_units[1] = mem.readIntLittle(u16, it.bytes[it.i..][0..2]);
+            const codepoint = try utf16DecodeSurrogatePair(&code_units);
             it.i += 2;
-            return 0x10000 + (((c0 & 0x03ff) << 10) | (c1 & 0x03ff));
-        } else if (c0 & ~@as(u21, 0x03ff) == 0xdc00) {
+            return codepoint;
+        } else if (utf16IsLowSurrogate(code_units[0])) {
             return error.UnexpectedSecondSurrogateHalf;
         } else {
-            return c0;
+            return code_units[0];
         }
     }
 };