Commit a241cf90d6

joadnacer <joad.nacer@gmail.com>
2023-10-13 01:38:55
std.base64: Improve Encoder/Decoder performance
1 parent b0f031f
Changed files (1)
lib
lib/std/base64.zig
@@ -102,14 +102,28 @@ pub const Base64Encoder = struct {
 
         var idx: usize = 0;
         var out_idx: usize = 0;
-        while (idx + 2 < source.len) : (idx += 3) {
+        while (idx + 15 < source.len) : (idx += 12) {
+            const bits = std.mem.readIntBig(u128, source[idx..][0..16]);
+            inline for (0..16) |i| {
+                dest[out_idx + i] = encoder.alphabet_chars[@truncate((bits >> (122 - i * 6)) & 0x3f)];
+            }
+            out_idx += 16;
+        }
+        while (idx + 3 < source.len) : (idx += 3) {
+            const bits = std.mem.readIntBig(u32, source[idx..][0..4]);
+            dest[out_idx] = encoder.alphabet_chars[(bits >> 26) & 0x3f];
+            dest[out_idx + 1] = encoder.alphabet_chars[(bits >> 20) & 0x3f];
+            dest[out_idx + 2] = encoder.alphabet_chars[(bits >> 14) & 0x3f];
+            dest[out_idx + 3] = encoder.alphabet_chars[(bits >> 8) & 0x3f];
+            out_idx += 4;
+        }
+        if (idx + 2 < source.len) {
             dest[out_idx] = encoder.alphabet_chars[source[idx] >> 2];
             dest[out_idx + 1] = encoder.alphabet_chars[((source[idx] & 0x3) << 4) | (source[idx + 1] >> 4)];
             dest[out_idx + 2] = encoder.alphabet_chars[(source[idx + 1] & 0xf) << 2 | (source[idx + 2] >> 6)];
             dest[out_idx + 3] = encoder.alphabet_chars[source[idx + 2] & 0x3f];
             out_idx += 4;
-        }
-        if (idx + 1 < source.len) {
+        } else if (idx + 1 < source.len) {
             dest[out_idx] = encoder.alphabet_chars[source[idx] >> 2];
             dest[out_idx + 1] = encoder.alphabet_chars[((source[idx] & 0x3) << 4) | (source[idx + 1] >> 4)];
             dest[out_idx + 2] = encoder.alphabet_chars[(source[idx + 1] & 0xf) << 2];
@@ -130,15 +144,18 @@ pub const Base64Encoder = struct {
 
 pub const Base64Decoder = struct {
     const invalid_char: u8 = 0xff;
+    const invalid_char_tst: u32 = 0xff000000;
 
     /// e.g. 'A' => 0.
     /// `invalid_char` for any value not in the 64 alphabet chars.
     char_to_index: [256]u8,
+    fast_char_to_index: [4][256]u32,
     pad_char: ?u8,
 
     pub fn init(alphabet_chars: [64]u8, pad_char: ?u8) Base64Decoder {
         var result = Base64Decoder{
             .char_to_index = [_]u8{invalid_char} ** 256,
+            .fast_char_to_index = .{[_]u32{invalid_char_tst} ** 256} ** 4,
             .pad_char = pad_char,
         };
 
@@ -147,6 +164,12 @@ pub const Base64Decoder = struct {
             assert(!char_in_alphabet[c]);
             assert(pad_char == null or c != pad_char.?);
 
+            const ci = @as(u32, @intCast(i));
+            result.fast_char_to_index[0][c] = ci << 2;
+            result.fast_char_to_index[1][c] = (ci >> 4) | ((ci & 0x0f) << 12);
+            result.fast_char_to_index[2][c] = ((ci & 0x3) << 22) | ((ci & 0x3c) << 6);
+            result.fast_char_to_index[3][c] = ci << 16;
+
             result.char_to_index[c] = @as(u8, @intCast(i));
             char_in_alphabet[c] = true;
         }
@@ -184,11 +207,39 @@ pub const Base64Decoder = struct {
     /// invalid padding results in error.InvalidPadding.
     pub fn decode(decoder: *const Base64Decoder, dest: []u8, source: []const u8) Error!void {
         if (decoder.pad_char != null and source.len % 4 != 0) return error.InvalidPadding;
+        var dest_idx: usize = 0;
+        var fast_src_idx: usize = 0;
         var acc: u12 = 0;
         var acc_len: u4 = 0;
-        var dest_idx: usize = 0;
         var leftover_idx: ?usize = null;
-        for (source, 0..) |c, src_idx| {
+        while (fast_src_idx + 16 < source.len and dest_idx + 15 < dest.len) : ({
+            fast_src_idx += 16;
+            dest_idx += 12;
+        }) {
+            var bits: u128 = 0;
+            inline for (0..4) |i| {
+                var new_bits: u128 = decoder.fast_char_to_index[0][source[fast_src_idx + i * 4]];
+                new_bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1 + i * 4]];
+                new_bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2 + i * 4]];
+                new_bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3 + i * 4]];
+                if ((new_bits & invalid_char_tst) != 0) return error.InvalidCharacter;
+                bits |= (new_bits << (24 * i));
+            }
+            std.mem.writeIntLittle(u128, dest[dest_idx..][0..16], bits);
+        }
+        while (fast_src_idx + 4 < source.len and dest_idx + 3 < dest.len) : ({
+            fast_src_idx += 4;
+            dest_idx += 3;
+        }) {
+            var bits = decoder.fast_char_to_index[0][source[fast_src_idx]];
+            bits |= decoder.fast_char_to_index[1][source[fast_src_idx + 1]];
+            bits |= decoder.fast_char_to_index[2][source[fast_src_idx + 2]];
+            bits |= decoder.fast_char_to_index[3][source[fast_src_idx + 3]];
+            if ((bits & invalid_char_tst) != 0) return error.InvalidCharacter;
+            std.mem.writeIntLittle(u32, dest[dest_idx..][0..4], bits);
+        }
+        var remaining = source[fast_src_idx..];
+        for (remaining, fast_src_idx..) |c, src_idx| {
             const d = decoder.char_to_index[c];
             if (d == invalid_char) {
                 if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter;
@@ -338,6 +389,10 @@ fn testBase64() !void {
     try testAllApis(codecs, "foob", "Zm9vYg==");
     try testAllApis(codecs, "fooba", "Zm9vYmE=");
     try testAllApis(codecs, "foobar", "Zm9vYmFy");
+    try testAllApis(codecs, "foobarfoobarfoo", "Zm9vYmFyZm9vYmFyZm9v");
+    try testAllApis(codecs, "foobarfoobarfoob", "Zm9vYmFyZm9vYmFyZm9vYg==");
+    try testAllApis(codecs, "foobarfoobarfooba", "Zm9vYmFyZm9vYmFyZm9vYmE=");
+    try testAllApis(codecs, "foobarfoobarfoobar", "Zm9vYmFyZm9vYmFyZm9vYmFy");
 
     try testDecodeIgnoreSpace(codecs, "", " ");
     try testDecodeIgnoreSpace(codecs, "f", "Z g= =");
@@ -357,11 +412,23 @@ fn testBase64() !void {
     try testError(codecs, "A/==", error.InvalidPadding);
     try testError(codecs, "A===", error.InvalidPadding);
     try testError(codecs, "====", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyA..A", error.InvalidCharacter);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyAA=A", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyAA/=", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyA/==", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyA===", error.InvalidPadding);
+    try testError(codecs, "A..AZm9vYmFyZm9vYmFy", error.InvalidCharacter);
+    try testError(codecs, "Zm9vYmFyZm9vAA=A", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vAA/=", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vA/==", error.InvalidPadding);
+    try testError(codecs, "Zm9vYmFyZm9vA===", error.InvalidPadding);
 
     try testNoSpaceLeftError(codecs, "AA==");
     try testNoSpaceLeftError(codecs, "AAA=");
     try testNoSpaceLeftError(codecs, "AAAA");
     try testNoSpaceLeftError(codecs, "AAAAAA==");
+
+    try testFourBytesDestNoSpaceLeftError(codecs, "AAAAAAAAAAAAAAAA");
 }
 
 fn testBase64UrlSafeNoPad() !void {
@@ -374,6 +441,7 @@ fn testBase64UrlSafeNoPad() !void {
     try testAllApis(codecs, "foob", "Zm9vYg");
     try testAllApis(codecs, "fooba", "Zm9vYmE");
     try testAllApis(codecs, "foobar", "Zm9vYmFy");
+    try testAllApis(codecs, "foobarfoobarfoobar", "Zm9vYmFyZm9vYmFyZm9vYmFy");
 
     try testDecodeIgnoreSpace(codecs, "", " ");
     try testDecodeIgnoreSpace(codecs, "f", "Z g ");
@@ -392,11 +460,15 @@ fn testBase64UrlSafeNoPad() !void {
     try testError(codecs, "A/==", error.InvalidCharacter);
     try testError(codecs, "A===", error.InvalidCharacter);
     try testError(codecs, "====", error.InvalidCharacter);
+    try testError(codecs, "Zm9vYmFyZm9vYmFyA..A", error.InvalidCharacter);
+    try testError(codecs, "A..AZm9vYmFyZm9vYmFy", error.InvalidCharacter);
 
     try testNoSpaceLeftError(codecs, "AA");
     try testNoSpaceLeftError(codecs, "AAA");
     try testNoSpaceLeftError(codecs, "AAAA");
     try testNoSpaceLeftError(codecs, "AAAAAA");
+
+    try testFourBytesDestNoSpaceLeftError(codecs, "AAAAAAAAAAAAAAAA");
 }
 
 fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: []const u8) !void {
@@ -457,3 +529,12 @@ fn testNoSpaceLeftError(codecs: Codecs, encoded: []const u8) !void {
         return error.ExpectedError;
     } else |err| if (err != error.NoSpaceLeft) return err;
 }
+
+fn testFourBytesDestNoSpaceLeftError(codecs: Codecs, encoded: []const u8) !void { // Test helper: decoding `encoded` into a 4-byte destination must fail with error.NoSpaceLeft.
+    const decoder_ignore_space = codecs.decoderWithIgnore(" "); // decoder obtained from `codecs` with " " passed as the ignore argument
+    var buffer: [0x100]u8 = undefined; // backing storage; only its first 4 bytes are handed to decode
+    var decoded = buffer[0..4]; // destination deliberately limited to 4 bytes
+    if (decoder_ignore_space.decode(decoded, encoded)) |_| {
+        return error.ExpectedError; // decode succeeded although a failure was expected
+    } else |err| if (err != error.NoSpaceLeft) return err; // NoSpaceLeft is the expected outcome (helper returns normally); any other error propagates
+}