Commit 5af89b3dcc

Frank Denis <124872+jedisct1@users.noreply.github.com>
2023-05-22 20:33:35
std.crypto.chacha: support larger vectors on AVX2 and AVX512 targets (#15809)
* std.crypto.chacha: support larger vectors on AVX2 and AVX512 targets Ryzen 7 7700, ChaCha20/8 stream, long outputs: Generic: 3268 MiB/s AVX2 : 6023 MiB/s AVX512 : 8086 MiB/s Bump the rand.chacha buffer a tiny bit to take advantage of this. More than 8 blocks doesn't seem to make any measurable difference. ChaChaPoly also gets a small performance boost from this, albeit Poly1305 remains the bottleneck. Generic: 707 MiB/s AVX2 : 981 MiB/s AVX512 : 1202 MiB/s aarch64 appears to generally benefit from 4-way vectorization. Verified on Apple Silicon, but also on a Cortex A72.
1 parent eef9275
Changed files (2)
lib
std
lib/std/crypto/chacha20.zig
@@ -76,30 +76,98 @@ pub const XChaCha12Poly1305 = XChaChaPoly1305(12);
 pub const XChaCha8Poly1305 = XChaChaPoly1305(8);
 
 // Vectorized implementation of the core function
-fn ChaChaVecImpl(comptime rounds_nb: usize) type {
+fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type {
     return struct {
-        const Lane = @Vector(4, u32);
+        const Lane = @Vector(4 * degree, u32);
         const BlockVec = [4]Lane;
 
         fn initContext(key: [8]u32, d: [4]u32) BlockVec {
             const c = "expand 32-byte k";
-            const constant_le = comptime Lane{
-                mem.readIntLittle(u32, c[0..4]),
-                mem.readIntLittle(u32, c[4..8]),
-                mem.readIntLittle(u32, c[8..12]),
-                mem.readIntLittle(u32, c[12..16]),
-            };
-            return BlockVec{
-                constant_le,
-                Lane{ key[0], key[1], key[2], key[3] },
-                Lane{ key[4], key[5], key[6], key[7] },
-                Lane{ d[0], d[1], d[2], d[3] },
-            };
+            switch (degree) {
+                1 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3] },
+                    };
+                },
+                2 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3] },
+                    };
+                },
+                4 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3], d[0] +% 2, d[1], d[2], d[3], d[0] +% 3, d[1], d[2], d[3] },
+                    };
+                },
+                else => @panic("invalid degree"),
+            }
         }
 
         inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
             x.* = input;
 
+            const m0 = switch (degree) {
+                1 => [_]i32{ 3, 0, 1, 2 },
+                2 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 },
+                4 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 } ++ [_]i32{ 11, 8, 9, 10 } ++ [_]i32{ 15, 12, 13, 14 },
+                else => @panic("invalid degree"),
+            };
+            const m1 = switch (degree) {
+                1 => [_]i32{ 2, 3, 0, 1 },
+                2 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 },
+                4 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 } ++ [_]i32{ 10, 11, 8, 9 } ++ [_]i32{ 14, 15, 12, 13 },
+                else => @panic("invalid degree"),
+            };
+            const m2 = switch (degree) {
+                1 => [_]i32{ 1, 2, 3, 0 },
+                2 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 },
+                4 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 } ++ [_]i32{ 9, 10, 11, 8 } ++ [_]i32{ 13, 14, 15, 12 },
+                else => @panic("invalid degree"),
+            };
+
             var r: usize = 0;
             while (r < rounds_nb) : (r += 2) {
                 x[0] +%= x[1];
@@ -112,13 +180,13 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
 
                 x[0] +%= x[1];
                 x[3] ^= x[0];
-                x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 3, 0, 1, 2 });
+                x[0] = @shuffle(u32, x[0], undefined, m0);
                 x[3] = math.rotl(Lane, x[3], 8);
 
                 x[2] +%= x[3];
-                x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
+                x[3] = @shuffle(u32, x[3], undefined, m1);
                 x[1] ^= x[2];
-                x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 1, 2, 3, 0 });
+                x[2] = @shuffle(u32, x[2], undefined, m2);
                 x[1] = math.rotl(Lane, x[1], 7);
 
                 x[0] +%= x[1];
@@ -131,24 +199,26 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
 
                 x[0] +%= x[1];
                 x[3] ^= x[0];
-                x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 1, 2, 3, 0 });
+                x[0] = @shuffle(u32, x[0], undefined, m2);
                 x[3] = math.rotl(Lane, x[3], 8);
 
                 x[2] +%= x[3];
-                x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
+                x[3] = @shuffle(u32, x[3], undefined, m1);
                 x[1] ^= x[2];
-                x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 3, 0, 1, 2 });
+                x[2] = @shuffle(u32, x[2], undefined, m0);
                 x[1] = math.rotl(Lane, x[1], 7);
             }
         }
 
-        inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
-            var i: usize = 0;
-            while (i < 4) : (i += 1) {
-                mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
-                mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
-                mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
-                mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
+        inline fn hashToBytes(comptime dm: usize, out: *[64 * dm]u8, x: BlockVec) void {
+            for (0..dm) |d| {
+                var i: usize = 0;
+                while (i < 4) : (i += 1) {
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 0 ..][0..4], x[i][0 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 4 ..][0..4], x[i][1 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 8 ..][0..4], x[i][2 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 12 ..][0..4], x[i][3 + 4 * d]);
+                }
             }
         }
 
@@ -162,29 +232,33 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
         fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
             var ctx = initContext(key, counter);
             var x: BlockVec = undefined;
-            var buf: [64]u8 = undefined;
+            var buf: [64 * degree]u8 = undefined;
             var i: usize = 0;
-            while (i + 64 <= in.len) : (i += 64) {
-                chacha20Core(x[0..], ctx);
-                contextFeedback(&x, ctx);
-                hashToBytes(buf[0..], x);
-
-                var xout = out[i..];
-                const xin = in[i..];
-                var j: usize = 0;
-                while (j < 64) : (j += 1) {
-                    xout[j] = xin[j];
-                }
-                j = 0;
-                while (j < 64) : (j += 1) {
-                    xout[j] ^= buf[j];
+            inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
+                while (degree >= d and i + 64 * d <= in.len) : (i += 64 * d) {
+                    chacha20Core(x[0..], ctx);
+                    contextFeedback(&x, ctx);
+                    hashToBytes(d, buf[0 .. 64 * d], x);
+
+                    var xout = out[i..];
+                    const xin = in[i..];
+                    var j: usize = 0;
+                    while (j < 64 * d) : (j += 1) {
+                        xout[j] = xin[j];
+                    }
+                    j = 0;
+                    while (j < 64 * d) : (j += 1) {
+                        xout[j] ^= buf[j];
+                    }
+                    inline for (0..d) |d_| {
+                        ctx[3][4 * d_] += @intCast(u32, d);
+                    }
                 }
-                ctx[3][0] += 1;
             }
             if (i < in.len) {
                 chacha20Core(x[0..], ctx);
                 contextFeedback(&x, ctx);
-                hashToBytes(buf[0..], x);
+                hashToBytes(1, buf[0..64], x);
 
                 var xout = out[i..];
                 const xin = in[i..];
@@ -199,18 +273,22 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
             var ctx = initContext(key, counter);
             var x: BlockVec = undefined;
             var i: usize = 0;
-            while (i + 64 <= out.len) : (i += 64) {
-                chacha20Core(x[0..], ctx);
-                contextFeedback(&x, ctx);
-                hashToBytes(out[i..][0..64], x);
-                ctx[3][0] += 1;
+            inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
+                while (degree >= d and i + 64 * d <= out.len) : (i += 64 * d) {
+                    chacha20Core(x[0..], ctx);
+                    contextFeedback(&x, ctx);
+                    hashToBytes(d, out[i..][0 .. 64 * d], x);
+                    inline for (0..d) |d_| {
+                        ctx[3][4 * d_] += @intCast(u32, d);
+                    }
+                }
             }
             if (i < out.len) {
                 chacha20Core(x[0..], ctx);
                 contextFeedback(&x, ctx);
 
                 var buf: [64]u8 = undefined;
-                hashToBytes(buf[0..], x);
+                hashToBytes(1, buf[0..], x);
                 @memcpy(out[i..], buf[0 .. out.len - i]);
             }
         }
@@ -399,7 +477,21 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
 }
 
 fn ChaChaImpl(comptime rounds_nb: usize) type {
-    return if (builtin.cpu.arch == .x86_64) ChaChaVecImpl(rounds_nb) else ChaChaNonVecImpl(rounds_nb);
+    switch (builtin.cpu.arch) {
+        .x86_64 => {
+            const has_avx2 = std.Target.x86.featureSetHas(builtin.cpu.features, .avx2);
+            const has_avx512f = std.Target.x86.featureSetHas(builtin.cpu.features, .avx512f);
+            if (has_avx512f) return ChaChaVecImpl(rounds_nb, 4);
+            if (has_avx2) return ChaChaVecImpl(rounds_nb, 2);
+            return ChaChaVecImpl(rounds_nb, 1);
+        },
+        .aarch64 => {
+            const has_neon = std.Target.aarch64.featureSetHas(builtin.cpu.features, .neon);
+            if (has_neon) return ChaChaVecImpl(rounds_nb, 4);
+            return ChaChaNonVecImpl(rounds_nb);
+        },
+        else => return ChaChaNonVecImpl(rounds_nb),
+    }
 }
 
 fn keyToWords(key: [32]u8) [8]u32 {
lib/std/rand/ChaCha.zig
@@ -10,7 +10,7 @@ const Self = @This();
 
 const Cipher = std.crypto.stream.chacha.ChaCha8IETF;
 
-const State = [2 * Cipher.block_length]u8;
+const State = [8 * Cipher.block_length]u8;
 
 state: State,
 offset: usize,