Commit 56929795a8

Carter Snook <cartersnook04@gmail.com>
2024-06-14 22:40:54
std.unicode: add encode overflow check function and friends
1 parent 82a934b
Changed files (1)
lib
lib/std/unicode.zig
@@ -1405,29 +1405,38 @@ test "ArrayList functions on a re-used list" {
     }
 }
 
-/// Converts a UTF-8 string literal into a UTF-16LE string literal.
-pub fn utf8ToUtf16LeStringLiteral(comptime utf8: []const u8) *const [calcUtf16LeLen(utf8) catch |err| @compileError(err):0]u16 {
+fn utf8ToUtf16LeStringLiteralImpl(comptime utf8: []const u8, comptime surrogates: Surrogates) *const [calcUtf16LeLenImpl(utf8, surrogates) catch |err| @compileError(err):0]u16 {
     return comptime blk: {
-        const len: usize = calcUtf16LeLen(utf8) catch unreachable;
+        const len: usize = calcUtf16LeLenImpl(utf8, surrogates) catch unreachable;
         var utf16le: [len:0]u16 = [_:0]u16{0} ** len;
-        const utf16le_len = utf8ToUtf16Le(&utf16le, utf8[0..]) catch |err| @compileError(err);
+        const utf16le_len = utf8ToUtf16LeImpl(&utf16le, utf8[0..], surrogates) catch |err| @compileError(err);
         assert(len == utf16le_len);
         const final = utf16le;
         break :blk &final;
     };
 }
 
-const CalcUtf16LeLenError = Utf8DecodeError || error{Utf8InvalidStartByte};
+/// Converts a UTF-8 string literal into a UTF-16LE string literal.
+pub fn utf8ToUtf16LeStringLiteral(comptime utf8: []const u8) *const [calcUtf16LeLen(utf8) catch |err| @compileError(err):0]u16 {
+    return utf8ToUtf16LeStringLiteralImpl(utf8, .cannot_encode_surrogate_half);
+}
 
-/// Returns length in UTF-16 of UTF-8 slice as length of []u16.
-/// Length in []u8 is 2*len16.
-pub fn calcUtf16LeLen(utf8: []const u8) CalcUtf16LeLenError!usize {
+/// Converts a WTF-8 string literal into a WTF-16LE string literal.
+pub fn wtf8ToWtf16LeStringLiteral(comptime wtf8: []const u8) *const [calcWtf16LeLen(wtf8) catch |err| @compileError(err):0]u16 {
+    return utf8ToUtf16LeStringLiteralImpl(wtf8, .can_encode_surrogate_half);
+}
+
+pub fn calcUtf16LeLenImpl(utf8: []const u8, comptime surrogates: Surrogates) !usize {
+    const utf8DecodeImpl = switch (surrogates) {
+        .cannot_encode_surrogate_half => utf8Decode,
+        .can_encode_surrogate_half => wtf8Decode,
+    };
     var src_i: usize = 0;
     var dest_len: usize = 0;
     while (src_i < utf8.len) {
         const n = try utf8ByteSequenceLength(utf8[src_i]);
         const next_src_i = src_i + n;
-        const codepoint = try utf8Decode(utf8[src_i..next_src_i]);
+        const codepoint = try utf8DecodeImpl(utf8[src_i..next_src_i]);
         if (codepoint < 0x10000) {
             dest_len += 1;
         } else {
@@ -1438,16 +1447,37 @@ pub fn calcUtf16LeLen(utf8: []const u8) CalcUtf16LeLenError!usize {
     return dest_len;
 }
 
-fn testCalcUtf16LeLen() !void {
-    try testing.expectEqual(@as(usize, 1), try calcUtf16LeLen("a"));
-    try testing.expectEqual(@as(usize, 10), try calcUtf16LeLen("abcdefghij"));
-    try testing.expectEqual(@as(usize, 10), try calcUtf16LeLen("äåéëþüúíóö"));
-    try testing.expectEqual(@as(usize, 5), try calcUtf16LeLen("こんにちは"));
+const CalcUtf16LeLenError = Utf8DecodeError || error{Utf8InvalidStartByte};
+
+/// Returns length in UTF-16LE of UTF-8 slice as length of []u16.
+/// Length in []u8 is 2*len16.
+pub fn calcUtf16LeLen(utf8: []const u8) CalcUtf16LeLenError!usize {
+    return calcUtf16LeLenImpl(utf8, .cannot_encode_surrogate_half);
+}
+
+const CalcWtf16LeLenError = Wtf8DecodeError || error{Utf8InvalidStartByte};
+
+/// Returns length in WTF-16LE of WTF-8 slice as length of []u16.
+/// Length in []u8 is 2*len16.
+pub fn calcWtf16LeLen(wtf8: []const u8) CalcWtf16LeLenError!usize {
+    return calcUtf16LeLenImpl(wtf8, .can_encode_surrogate_half);
 }
 
-test "calculate utf16 string length of given utf8 string in u16" {
-    try testCalcUtf16LeLen();
-    try comptime testCalcUtf16LeLen();
+fn testCalcUtf16LeLenImpl(calcUtf16LeLenImpl_: anytype) !void {
+    try testing.expectEqual(@as(usize, 1), try calcUtf16LeLenImpl_("a"));
+    try testing.expectEqual(@as(usize, 10), try calcUtf16LeLenImpl_("abcdefghij"));
+    try testing.expectEqual(@as(usize, 10), try calcUtf16LeLenImpl_("äåéëþüúíóö"));
+    try testing.expectEqual(@as(usize, 5), try calcUtf16LeLenImpl_("こんにちは"));
+}
+
+test calcUtf16LeLen {
+    try testCalcUtf16LeLenImpl(calcUtf16LeLen);
+    try comptime testCalcUtf16LeLenImpl(calcUtf16LeLen);
+}
+
+test calcWtf16LeLen {
+    try testCalcUtf16LeLenImpl(calcWtf16LeLen);
+    try comptime testCalcUtf16LeLenImpl(calcWtf16LeLen);
 }
 
 /// Print the given `utf16le` string, encoded as UTF-8 bytes.
@@ -1487,8 +1517,10 @@ pub fn fmtUtf16Le(utf16le: []const u16) std.fmt.Formatter(formatUtf16Le) {
 test fmtUtf16Le {
     const expectFmt = testing.expectFmt;
     try expectFmt("", "{}", .{fmtUtf16Le(utf8ToUtf16LeStringLiteral(""))});
+    try expectFmt("", "{}", .{fmtUtf16Le(wtf8ToWtf16LeStringLiteral(""))});
     try expectFmt("foo", "{}", .{fmtUtf16Le(utf8ToUtf16LeStringLiteral("foo"))});
-    try expectFmt("𐐷", "{}", .{fmtUtf16Le(utf8ToUtf16LeStringLiteral("𐐷"))});
+    try expectFmt("foo", "{}", .{fmtUtf16Le(wtf8ToWtf16LeStringLiteral("foo"))});
+    try expectFmt("𐐷", "{}", .{fmtUtf16Le(wtf8ToWtf16LeStringLiteral("𐐷"))});
     try expectFmt("퟿", "{}", .{fmtUtf16Le(&[_]u16{mem.readInt(u16, "\xff\xd7", native_endian)})});
     try expectFmt("�", "{}", .{fmtUtf16Le(&[_]u16{mem.readInt(u16, "\x00\xd8", native_endian)})});
     try expectFmt("�", "{}", .{fmtUtf16Le(&[_]u16{mem.readInt(u16, "\xff\xdb", native_endian)})});
@@ -1497,12 +1529,12 @@ test fmtUtf16Le {
     try expectFmt("", "{}", .{fmtUtf16Le(&[_]u16{mem.readInt(u16, "\x00\xe0", native_endian)})});
 }
 
-test utf8ToUtf16LeStringLiteral {
+fn testUtf8ToUtf16LeStringLiteral(utf8ToUtf16LeStringLiteral_: anytype) !void {
     {
         const bytes = [_:0]u16{
             mem.nativeToLittle(u16, 0x41),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("A");
+        const utf16 = utf8ToUtf16LeStringLiteral_("A");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[1] == 0);
     }
@@ -1511,7 +1543,7 @@ test utf8ToUtf16LeStringLiteral {
             mem.nativeToLittle(u16, 0xD801),
             mem.nativeToLittle(u16, 0xDC37),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("𐐷");
+        const utf16 = utf8ToUtf16LeStringLiteral_("𐐷");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[2] == 0);
     }
@@ -1519,7 +1551,7 @@ test utf8ToUtf16LeStringLiteral {
         const bytes = [_:0]u16{
             mem.nativeToLittle(u16, 0x02FF),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("\u{02FF}");
+        const utf16 = utf8ToUtf16LeStringLiteral_("\u{02FF}");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[1] == 0);
     }
@@ -1527,7 +1559,7 @@ test utf8ToUtf16LeStringLiteral {
         const bytes = [_:0]u16{
             mem.nativeToLittle(u16, 0x7FF),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("\u{7FF}");
+        const utf16 = utf8ToUtf16LeStringLiteral_("\u{7FF}");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[1] == 0);
     }
@@ -1535,7 +1567,7 @@ test utf8ToUtf16LeStringLiteral {
         const bytes = [_:0]u16{
             mem.nativeToLittle(u16, 0x801),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("\u{801}");
+        const utf16 = utf8ToUtf16LeStringLiteral_("\u{801}");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[1] == 0);
     }
@@ -1544,12 +1576,20 @@ test utf8ToUtf16LeStringLiteral {
             mem.nativeToLittle(u16, 0xDBFF),
             mem.nativeToLittle(u16, 0xDFFF),
         };
-        const utf16 = utf8ToUtf16LeStringLiteral("\u{10FFFF}");
+        const utf16 = utf8ToUtf16LeStringLiteral_("\u{10FFFF}");
         try testing.expectEqualSlices(u16, &bytes, utf16);
         try testing.expect(utf16[2] == 0);
     }
 }
 
+test utf8ToUtf16LeStringLiteral {
+    try testUtf8ToUtf16LeStringLiteral(utf8ToUtf16LeStringLiteral);
+}
+
+test wtf8ToWtf16LeStringLiteral {
+    try testUtf8ToUtf16LeStringLiteral(wtf8ToWtf16LeStringLiteral);
+}
+
 fn testUtf8CountCodepoints() !void {
     try testing.expectEqual(@as(usize, 10), try utf8CountCodepoints("abcdefghij"));
     try testing.expectEqual(@as(usize, 10), try utf8CountCodepoints("äåéëþüúíóö"));
@@ -1795,6 +1835,30 @@ pub fn wtf8ToWtf16Le(wtf16le: []u16, wtf8: []const u8) error{InvalidWtf8}!usize
     return utf8ToUtf16LeImpl(wtf16le, wtf8, .can_encode_surrogate_half);
 }
 
+fn checkUtf8ToUtf16LeOverflowImpl(utf8: []const u8, utf16le: []const u16, comptime surrogates: Surrogates) !bool {
+    // Each u8 in UTF-8/WTF-8 correlates to at most one u16 in UTF-16LE/WTF-16LE.
+    if (utf16le.len >= utf8.len) return false;
+    const utf16_len = calcUtf16LeLenImpl(utf8, surrogates) catch {
+        return switch (surrogates) {
+            .cannot_encode_surrogate_half => error.InvalidUtf8,
+            .can_encode_surrogate_half => error.InvalidWtf8,
+        };
+    };
+    return utf16_len > utf16le.len;
+}
+
+/// Checks if calling `utf8ToUtf16Le` would overflow `utf16le`. May return
+/// `error.InvalidUtf8`; input is not validated when `utf16le.len >= utf8.len`.
+pub fn checkUtf8ToUtf16LeOverflow(utf8: []const u8, utf16le: []const u16) error{InvalidUtf8}!bool {
+    return checkUtf8ToUtf16LeOverflowImpl(utf8, utf16le, .cannot_encode_surrogate_half);
+}
+
+/// Checks if calling `wtf8ToWtf16Le` would overflow. Might fail if wtf8 is not
+/// valid WTF-8.
+pub fn checkWtf8ToWtf16LeOverflow(wtf8: []const u8, wtf16le: []const u16) error{InvalidWtf8}!bool {
+    return checkUtf8ToUtf16LeOverflowImpl(wtf8, wtf16le, .can_encode_surrogate_half);
+}
+
 /// Surrogate codepoints (U+D800 to U+DFFF) are replaced by the Unicode replacement
 /// character (U+FFFD).
 /// All surrogate codepoints and the replacement character are encoded as three
@@ -2000,6 +2064,8 @@ fn testRoundtripWtf8(wtf8: []const u8) !void {
     {
         var wtf16_buf: [32]u16 = undefined;
         const wtf16_len = try wtf8ToWtf16Le(&wtf16_buf, wtf8);
+        try testing.expectEqual(wtf16_len, calcWtf16LeLen(wtf8));
+        try testing.expectEqual(false, checkWtf8ToWtf16LeOverflow(wtf8, &wtf16_buf));
         const wtf16 = wtf16_buf[0..wtf16_len];
 
         var roundtripped_buf: [32]u8 = undefined;