Commit 9b386bda33

Frank Denis <github@pureftpd.org>
2020-10-09 23:19:27
std/crypto: add a vectorized ChaCha20 implementation
Brings a 30% speed boost on x86_64 even though we still process only one block at a time for now. Only enabled on x86_64 since the non-vectorized implementation seems to currently perform better on some architectures (at least on aarch64). But the non-vectorized implementation still gets a little speed boost as well (~17%) with these changes.
1 parent 53c63bd
Changed files (1)
lib
std
lib/std/crypto/chacha20.zig
@@ -10,120 +10,315 @@ const mem = std.mem;
 const assert = std.debug.assert;
 const testing = std.testing;
 const maxInt = std.math.maxInt;
+const Vector = std.meta.Vector;
 const Poly1305 = std.crypto.onetimeauth.Poly1305;
 
-const QuarterRound = struct {
-    a: usize,
-    b: usize,
-    c: usize,
-    d: usize,
-};
+// Vectorized implementation of the core function
+const ChaCha20VecImpl = struct {
+    const Lane = Vector(4, u32);
+    const BlockVec = [4]Lane;
+
+    fn initContext(key: [8]u32, d: [4]u32) BlockVec {
+        const c = "expand 32-byte k";
+        const constant_le = comptime Lane{
+            mem.readIntLittle(u32, c[0..4]),
+            mem.readIntLittle(u32, c[4..8]),
+            mem.readIntLittle(u32, c[8..12]),
+            mem.readIntLittle(u32, c[12..16]),
+        };
+        return BlockVec{
+            constant_le,
+            Lane{ key[0], key[1], key[2], key[3] },
+            Lane{ key[4], key[5], key[6], key[7] },
+            Lane{ d[0], d[1], d[2], d[3] },
+        };
+    }
 
-fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound {
-    return QuarterRound{
-        .a = a,
-        .b = b,
-        .c = c,
-        .d = d,
-    };
-}
+    inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
+        const rot8 = Vector(16, i32){ 3, 0, 1, 2, 7, 4, 5, 6, 11, 8, 9, 10, 15, 12, 13, 14 };
+        const rot16 = Vector(16, i32){ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13 };
+
+        x.* = input;
+
+        var r: usize = 0;
+        while (r < 20) : (r += 2) {
+            x[0] +%= x[1];
+            x[3] ^= x[0];
+            x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16));
+
+            x[2] +%= x[3];
+            x[1] ^= x[2];
+
+            var t1 = x[1];
+            x[1] <<= @splat(4, @as(u5, 12));
+            t1 >>= @splat(4, @as(u5, 20));
+            x[1] ^= t1;
+
+            x[0] +%= x[1];
+            x[3] ^= x[0];
+            x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 3, 0, 1, 2 });
+            x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8));
+
+            x[2] +%= x[3];
+            x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 });
+            x[1] ^= x[2];
+            x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 1, 2, 3, 0 });
+
+            t1 = x[1];
+            x[1] <<= @splat(4, @as(u5, 7));
+            t1 >>= @splat(4, @as(u5, 25));
+            x[1] ^= t1;
+
+            x[0] +%= x[1];
+            x[3] ^= x[0];
+            x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16));
+
+            x[2] +%= x[3];
+            x[1] ^= x[2];
+
+            t1 = x[1];
+            x[1] <<= @splat(4, @as(u5, 12));
+            t1 >>= @splat(4, @as(u5, 20));
+            x[1] ^= t1;
+
+            x[0] +%= x[1];
+            x[3] ^= x[0];
+            x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 1, 2, 3, 0 });
+            x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8));
+
+            x[2] +%= x[3];
+            x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 });
+            x[1] ^= x[2];
+            x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 3, 0, 1, 2 });
+
+            t1 = x[1];
+            x[1] <<= @splat(4, @as(u5, 7));
+            t1 >>= @splat(4, @as(u5, 25));
+            x[1] ^= t1;
+        }
+    }
 
-fn initContext(key: [8]u32, d: [4]u32) [16]u32 {
-    var ctx: [16]u32 = undefined;
-    const c = "expand 32-byte k";
-    const constant_le = comptime [_]u32{
-        mem.readIntLittle(u32, c[0..4]),
-        mem.readIntLittle(u32, c[4..8]),
-        mem.readIntLittle(u32, c[8..12]),
-        mem.readIntLittle(u32, c[12..16]),
-    };
-    mem.copy(u32, ctx[0..], constant_le[0..4]);
-    mem.copy(u32, ctx[4..12], key[0..8]);
-    mem.copy(u32, ctx[12..16], d[0..4]);
+    inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
+        var i: usize = 0;
+        while (i < 4) : (i += 1) {
+            mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
+            mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
+            mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
+            mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
+        }
+    }
 
-    return ctx;
-}
+    inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
+        x[0] +%= ctx[0];
+        x[1] +%= ctx[1];
+        x[2] +%= ctx[2];
+        x[3] +%= ctx[3];
+    }
 
-// The chacha family of ciphers are based on the salsa family.
-inline fn chacha20Core(x: []u32, input: [16]u32) void {
-    for (x) |_, i|
-        x[i] = input[i];
-
-    const rounds = comptime [_]QuarterRound{
-        Rp(0, 4, 8, 12),
-        Rp(1, 5, 9, 13),
-        Rp(2, 6, 10, 14),
-        Rp(3, 7, 11, 15),
-        Rp(0, 5, 10, 15),
-        Rp(1, 6, 11, 12),
-        Rp(2, 7, 8, 13),
-        Rp(3, 4, 9, 14),
-    };
+    fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
+        var ctx = initContext(key, counter);
+        var x: BlockVec = undefined;
+        var buf: [64]u8 = undefined;
+        var i: usize = 0;
+        while (i + 64 <= in.len) : (i += 64) {
+            chacha20Core(x[0..], ctx);
+            contextFeedback(&x, ctx);
+            hashToBytes(buf[0..], x);
+
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] = xin[j];
+            }
+            j = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] ^= buf[j];
+            }
+            ctx[3][0] += 1;
+        }
+        if (i < in.len) {
+            chacha20Core(x[0..], ctx);
+            contextFeedback(&x, ctx);
+            hashToBytes(buf[0..], x);
+
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < in.len % 64) : (j += 1) {
+                xout[j] = xin[j] ^ buf[j];
+            }
+        }
+    }
 
-    comptime var j: usize = 0;
-    inline while (j < 20) : (j += 2) {
-        // two-round cycles
-        inline for (rounds) |r| {
-            x[r.a] +%= x[r.b];
-            x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16));
-            x[r.c] +%= x[r.d];
-            x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12));
-            x[r.a] +%= x[r.b];
-            x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8));
-            x[r.c] +%= x[r.d];
-            x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7));
+    fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
+        var c: [4]u32 = undefined;
+        for (c) |_, i| {
+            c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
         }
+        const ctx = initContext(keyToWords(key), c);
+        var x: BlockVec = undefined;
+        chacha20Core(x[0..], ctx);
+        var out: [32]u8 = undefined;
+        mem.writeIntLittle(u32, out[0..4], x[0][0]);
+        mem.writeIntLittle(u32, out[4..8], x[0][1]);
+        mem.writeIntLittle(u32, out[8..12], x[0][2]);
+        mem.writeIntLittle(u32, out[12..16], x[0][3]);
+        mem.writeIntLittle(u32, out[16..20], x[3][0]);
+        mem.writeIntLittle(u32, out[20..24], x[3][1]);
+        mem.writeIntLittle(u32, out[24..28], x[3][2]);
+        mem.writeIntLittle(u32, out[28..32], x[3][3]);
+        return out;
     }
-}
+};
 
-fn hashToBytes(out: []u8, x: [16]u32) void {
-    for (x) |_, i| {
-        mem.writeIntLittle(u32, out[4 * i ..][0..4], x[i]);
+// Non-vectorized implementation of the core function
+const ChaCha20NonVecImpl = struct {
+    const BlockVec = [16]u32;
+
+    fn initContext(key: [8]u32, d: [4]u32) BlockVec {
+        const c = "expand 32-byte k";
+        const constant_le = comptime [4]u32{
+            mem.readIntLittle(u32, c[0..4]),
+            mem.readIntLittle(u32, c[4..8]),
+            mem.readIntLittle(u32, c[8..12]),
+            mem.readIntLittle(u32, c[12..16]),
+        };
+        return BlockVec{
+            constant_le[0], constant_le[1], constant_le[2], constant_le[3],
+            key[0],         key[1],         key[2],         key[3],
+            key[4],         key[5],         key[6],         key[7],
+            d[0],           d[1],           d[2],           d[3],
+        };
     }
-}
 
-fn chaCha20_internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
-    var ctx = initContext(key, counter);
-    var remaining: usize = if (in.len > out.len) in.len else out.len;
-    var cursor: usize = 0;
+    const QuarterRound = struct {
+        a: usize,
+        b: usize,
+        c: usize,
+        d: usize,
+    };
 
-    while (true) {
-        var x: [16]u32 = undefined;
-        var buf: [64]u8 = undefined;
-        chacha20Core(x[0..], ctx);
-        for (x) |_, i| {
-            x[i] +%= ctx[i];
+    fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound {
+        return QuarterRound{
+            .a = a,
+            .b = b,
+            .c = c,
+            .d = d,
+        };
+    }
+
+    inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
+        x.* = input;
+
+        const rounds = comptime [_]QuarterRound{
+            Rp(0, 4, 8, 12),
+            Rp(1, 5, 9, 13),
+            Rp(2, 6, 10, 14),
+            Rp(3, 7, 11, 15),
+            Rp(0, 5, 10, 15),
+            Rp(1, 6, 11, 12),
+            Rp(2, 7, 8, 13),
+            Rp(3, 4, 9, 14),
+        };
+
+        comptime var j: usize = 0;
+        inline while (j < 20) : (j += 2) {
+            inline for (rounds) |r| {
+                x[r.a] +%= x[r.b];
+                x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16));
+                x[r.c] +%= x[r.d];
+                x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12));
+                x[r.a] +%= x[r.b];
+                x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8));
+                x[r.c] +%= x[r.d];
+                x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7));
+            }
         }
-        hashToBytes(buf[0..], x);
-        if (remaining < 64) {
-            var i: usize = 0;
-            while (i < remaining) : (i += 1)
-                out[cursor + i] = in[cursor + i] ^ buf[i];
-            return;
+    }
+
+    inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
+        var i: usize = 0;
+        while (i < 4) : (i += 1) {
+            mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]);
+            mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]);
+            mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]);
+            mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]);
         }
+    }
 
+    inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
         var i: usize = 0;
-        while (i < 64) : (i += 1)
-            out[cursor + i] = in[cursor + i] ^ buf[i];
+        while (i < 16) : (i += 1) {
+            x[i] +%= ctx[i];
+        }
+    }
 
-        cursor += 64;
-        remaining -= 64;
+    fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
+        var ctx = initContext(key, counter);
+        var x: BlockVec = undefined;
+        var buf: [64]u8 = undefined;
+        var i: usize = 0;
+        while (i + 64 <= in.len) : (i += 64) {
+            chacha20Core(x[0..], ctx);
+            contextFeedback(&x, ctx);
+            hashToBytes(buf[0..], x);
+
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] = xin[j];
+            }
+            j = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] ^= buf[j];
+            }
+            ctx[12] += 1;
+        }
+        if (i < in.len) {
+            chacha20Core(x[0..], ctx);
+            contextFeedback(&x, ctx);
+            hashToBytes(buf[0..], x);
+
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < in.len % 64) : (j += 1) {
+                xout[j] = xin[j] ^ buf[j];
+            }
+        }
+    }
 
-        ctx[12] += 1;
+    fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
+        var c: [4]u32 = undefined;
+        for (c) |_, i| {
+            c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
+        }
+        const ctx = initContext(keyToWords(key), c);
+        var x: BlockVec = undefined;
+        chacha20Core(x[0..], ctx);
+        var out: [32]u8 = undefined;
+        mem.writeIntLittle(u32, out[0..4], x[0]);
+        mem.writeIntLittle(u32, out[4..8], x[1]);
+        mem.writeIntLittle(u32, out[8..12], x[2]);
+        mem.writeIntLittle(u32, out[12..16], x[3]);
+        mem.writeIntLittle(u32, out[16..20], x[12]);
+        mem.writeIntLittle(u32, out[20..24], x[13]);
+        mem.writeIntLittle(u32, out[24..28], x[14]);
+        mem.writeIntLittle(u32, out[28..32], x[15]);
+        return out;
     }
-}
+};
+
+const ChaCha20Impl = if (std.Target.current.cpu.arch == .x86_64) ChaCha20VecImpl else ChaCha20NonVecImpl;
 
 fn keyToWords(key: [32]u8) [8]u32 {
     var k: [8]u32 = undefined;
-    k[0] = mem.readIntLittle(u32, key[0..4]);
-    k[1] = mem.readIntLittle(u32, key[4..8]);
-    k[2] = mem.readIntLittle(u32, key[8..12]);
-    k[3] = mem.readIntLittle(u32, key[12..16]);
-    k[4] = mem.readIntLittle(u32, key[16..20]);
-    k[5] = mem.readIntLittle(u32, key[20..24]);
-    k[6] = mem.readIntLittle(u32, key[24..28]);
-    k[7] = mem.readIntLittle(u32, key[28..32]);
-
+    var i: usize = 0;
+    while (i < 8) : (i += 1) {
+        k[i] = mem.readIntLittle(u32, key[i * 4 ..][0..4]);
+    }
     return k;
 }
 
@@ -145,7 +340,7 @@ pub const ChaCha20IETF = struct {
         c[1] = mem.readIntLittle(u32, nonce[0..4]);
         c[2] = mem.readIntLittle(u32, nonce[4..8]);
         c[3] = mem.readIntLittle(u32, nonce[8..12]);
-        chaCha20_internal(out, in, keyToWords(key), c);
+        ChaCha20Impl.chaCha20Internal(out, in, keyToWords(key), c);
     }
 };
 
@@ -171,7 +366,7 @@ pub const ChaCha20With64BitNonce = struct {
 
         // first partial big block
         if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) {
-            chaCha20_internal(out[cursor..big_block], in[cursor..big_block], k, c);
+            ChaCha20Impl.chaCha20Internal(out[cursor..big_block], in[cursor..big_block], k, c);
             cursor = big_block - cursor;
             c[1] += 1;
             if (comptime @sizeOf(usize) > 4) {
@@ -179,14 +374,14 @@ pub const ChaCha20With64BitNonce = struct {
                 var remaining_blocks: u32 = @intCast(u32, (in.len / big_block));
                 var i: u32 = 0;
                 while (remaining_blocks > 0) : (remaining_blocks -= 1) {
-                    chaCha20_internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c);
-                    c[1] += 1; // upper 32-bit of counter, generic chaCha20_internal() doesn't know about this.
+                    ChaCha20Impl.chaCha20Internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c);
+                    c[1] += 1; // upper 32-bit of counter, generic chaCha20Internal() doesn't know about this.
                     cursor += big_block;
                 }
             }
         }
 
-        chaCha20_internal(out[cursor..], in[cursor..], k, c);
+        ChaCha20Impl.chaCha20Internal(out[cursor..], in[cursor..], k, c);
     }
 };
 
@@ -533,33 +728,12 @@ fn chacha20poly1305Open(dst: []u8, ciphertextAndTag: []const u8, data: []const u
     return try chacha20poly1305OpenDetached(dst, ciphertextAndTag[0..ciphertextLen], ciphertextAndTag[ciphertextLen..][0..chacha20poly1305_tag_size], data, key, nonce);
 }
 
-fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
-    var c: [4]u32 = undefined;
-    for (c) |_, i| {
-        c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
-    }
-    const ctx = initContext(keyToWords(key), c);
-    var x: [16]u32 = undefined;
-    chacha20Core(x[0..], ctx);
-    var out: [32]u8 = undefined;
-    mem.writeIntLittle(u32, out[0..4], x[0]);
-    mem.writeIntLittle(u32, out[4..8], x[1]);
-    mem.writeIntLittle(u32, out[8..12], x[2]);
-    mem.writeIntLittle(u32, out[12..16], x[3]);
-    mem.writeIntLittle(u32, out[16..20], x[12]);
-    mem.writeIntLittle(u32, out[20..24], x[13]);
-    mem.writeIntLittle(u32, out[24..28], x[14]);
-    mem.writeIntLittle(u32, out[28..32], x[15]);
-
-    return out;
-}
-
 fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [12]u8 } {
     var subnonce: [12]u8 = undefined;
     mem.set(u8, subnonce[0..4], 0);
     mem.copy(u8, subnonce[4..], nonce[16..24]);
     return .{
-        .key = hchacha20(nonce[0..16].*, key),
+        .key = ChaCha20Impl.hchacha20(nonce[0..16].*, key),
         .nonce = subnonce,
     };
 }