Commit 2387292f20

Josh Wolfe <thejoshwolfe@gmail.com>
2018-04-29 23:28:11
move some checks around in utf8Encode logic to be more zig idiomatic
1 parent 8c567d8
Changed files (1)
std/unicode.zig
@@ -1,12 +1,11 @@
 const std = @import("./index.zig");
 const debug = std.debug;
 
-// Given a Utf8-Codepoint returns how many (1-4)
-// bytes there are if represented as an array of bytes.
+/// Returns how many bytes the UTF-8 representation would require
+/// for the given codepoint.
 pub fn utf8CodepointSequenceLength(c: u32) !u3 {
     if (c < 0x80) return u3(1);
     if (c < 0x800) return u3(2);
-    if (c -% 0xd800 < 0x800) return error.InvalidCodepoint;
     if (c < 0x10000) return u3(3);
     if (c < 0x110000) return u3(4);
     return error.CodepointTooLarge;
@@ -23,45 +22,39 @@ pub fn utf8ByteSequenceLength(first_byte: u8) !u3 {
     return error.Utf8InvalidStartByte;
 }
 
-/// Encodes a code point back into utf8
-/// c: the code point
-/// out: the out buffer to write to
-/// Notes: out has to have a len big enough for the bytes
-///        however this limit is dependent on the code point
-///        but giving it a minimum of 4 will ensure it will work
-///        for all code points.
-/// Errors: Will return an error if the code point is invalid.
+/// Encodes the given codepoint into a UTF-8 byte sequence.
+/// c: the codepoint.
+/// out: the out buffer to write to. Must have a len >= utf8CodepointSequenceLength(c).
+/// Errors: if c cannot be encoded in UTF-8.
+/// Returns: the number of bytes written to out.
 pub fn utf8Encode(c: u32, out: []u8) !u3 {
-    if (utf8CodepointSequenceLength(c)) |length| {
-        debug.assert(out.len >= length);
-        switch (length) {
-            // The pattern for each is the same
-            // - Increasing the initial shift by 6 each time
-            // - Each time after the first shorten the shifted
-            //   value to a max of 0b111111 (63)
-            1 => out[0] = u8(c), // Can just do 0 + codepoint for initial range
-            2 => {
-                out[0] = u8(0b11000000 | (c >> 6));
-                out[1] = u8(0b10000000 | (c & 0b111111));
-            },
-            3 => {
-                out[0] = u8(0b11100000 | (c >> 12));
-                out[1] = u8(0b10000000 | ((c >> 6) & 0b111111));
-                out[2] = u8(0b10000000 | (c & 0b111111));
-            },
-            4 => {
-                out[0] = u8(0b11110000 | (c >> 18));
-                out[1] = u8(0b10000000 | ((c >> 12) & 0b111111));
-                out[2] = u8(0b10000000 | ((c >> 6) & 0b111111));
-                out[3] = u8(0b10000000 | (c & 0b111111));
-            },
-            else => unreachable,
-        }
-
-        return length;
-    } else |err| {
-        return err;
+    const length = try utf8CodepointSequenceLength(c);
+    debug.assert(out.len >= length);
+    switch (length) {
+        // The pattern for each is the same
+        // - Increasing the initial shift by 6 each time
+        // - Each time after the first shorten the shifted
+        //   value to a max of 0b111111 (63)
+        1 => out[0] = u8(c), // Can just do 0 + codepoint for initial range
+        2 => {
+            out[0] = u8(0b11000000 | (c >> 6));
+            out[1] = u8(0b10000000 | (c & 0b111111));
+        },
+        3 => {
+            if (0xd800 <= c and c <= 0xdfff) return error.Utf8CannotEncodeSurrogateHalf;
+            out[0] = u8(0b11100000 | (c >> 12));
+            out[1] = u8(0b10000000 | ((c >> 6) & 0b111111));
+            out[2] = u8(0b10000000 | (c & 0b111111));
+        },
+        4 => {
+            out[0] = u8(0b11110000 | (c >> 18));
+            out[1] = u8(0b10000000 | ((c >> 12) & 0b111111));
+            out[2] = u8(0b10000000 | ((c >> 6) & 0b111111));
+            out[3] = u8(0b10000000 | (c & 0b111111));
+        },
+        else => unreachable,
     }
+    return length;
 }
 
 /// Decodes the UTF-8 codepoint encoded in the given slice of bytes.
@@ -249,8 +242,10 @@ test "utf8 encode" {
 
 test "utf8 encode error" {
     var array: [4]u8 = undefined;
-    testErrorEncode(0xFFFFFF, array[0..], error.CodepointTooLarge);
-    testErrorEncode(0xd900, array[0..], error.InvalidCodepoint);
+    testErrorEncode(0xd800, array[0..], error.Utf8CannotEncodeSurrogateHalf);
+    testErrorEncode(0xdfff, array[0..], error.Utf8CannotEncodeSurrogateHalf);
+    testErrorEncode(0x110000, array[0..], error.CodepointTooLarge);
+    testErrorEncode(0xffffffff, array[0..], error.CodepointTooLarge);
 }
 
 fn testErrorEncode(codePoint: u32, array: []u8, expectedErr: error) void {