Commit 10edb6d352

Cody Tapscott <topolarity@tapscott.me>
2022-10-23 08:50:38
crypto.sha2: Use intrinsics for SHA-256 on x86-64 and AArch64
There's probably plenty of room to optimize these further in the future, but for the moment this gives ~3x improvement on Intel x86-64 processors, ~5x on AMD, and ~10x on M1 Macs. These extensions are very new - Most processors prior to 2020 do not support them. AVX-512 is a slightly older alternative that we could use on Intel for a much bigger performance bump, but it's been fused off on Intel's latest hybrid architectures and it relies on computing independent SHA hashes in parallel. In contrast, these SHA intrinsics provide the usual single-threaded, single-stream interface, and should continue working on new processors. AArch64 also has SHA-512 intrinsics that we could take advantage of in the future
1 parent c616141
Changed files (1)
lib
std
crypto
lib/std/crypto/sha2.zig
@@ -1,4 +1,5 @@
 const std = @import("../std.zig");
+const builtin = @import("builtin");
 const mem = std.mem;
 const math = std.math;
 const htest = @import("test.zig");
@@ -16,10 +17,9 @@ const RoundParam256 = struct {
     g: usize,
     h: usize,
     i: usize,
-    k: u32,
 };
 
-fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g: usize, h: usize, i: usize, k: u32) RoundParam256 {
+fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g: usize, h: usize, i: usize) RoundParam256 {
     return RoundParam256{
         .a = a,
         .b = b,
@@ -30,7 +30,6 @@ fn roundParam256(a: usize, b: usize, c: usize, d: usize, e: usize, f: usize, g:
         .g = g,
         .h = h,
         .i = i,
-        .k = k,
     };
 }
 
@@ -70,6 +69,8 @@ const Sha256Params = Sha2Params32{
     .digest_bits = 256,
 };
 
+const v4u32 = @Vector(4, u32);
+
 /// SHA-224
 pub const Sha224 = Sha2x32(Sha224Params);
 
@@ -83,7 +84,7 @@ fn Sha2x32(comptime params: Sha2Params32) type {
         pub const digest_length = params.digest_bits / 8;
         pub const Options = struct {};
 
-        s: [8]u32,
+        s: [8]u32 align(16),
         // Streaming Cache
         buf: [64]u8 = undefined,
         buf_len: u8 = 0,
@@ -168,8 +169,19 @@ fn Sha2x32(comptime params: Sha2Params32) type {
             }
         }
 
+        const W = [64]u32{
+            0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5,
+            0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174,
+            0xE49B69C1, 0xEFBE4786, 0x0FC19DC6, 0x240CA1CC, 0x2DE92C6F, 0x4A7484AA, 0x5CB0A9DC, 0x76F988DA,
+            0x983E5152, 0xA831C66D, 0xB00327C8, 0xBF597FC7, 0xC6E00BF3, 0xD5A79147, 0x06CA6351, 0x14292967,
+            0x27B70A85, 0x2E1B2138, 0x4D2C6DFC, 0x53380D13, 0x650A7354, 0x766A0ABB, 0x81C2C92E, 0x92722C85,
+            0xA2BFE8A1, 0xA81A664B, 0xC24B8B70, 0xC76C51A3, 0xD192E819, 0xD6990624, 0xF40E3585, 0x106AA070,
+            0x19A4C116, 0x1E376C08, 0x2748774C, 0x34B0BCB5, 0x391C0CB3, 0x4ED8AA4A, 0x5B9CCA4F, 0x682E6FF3,
+            0x748F82EE, 0x78A5636F, 0x84C87814, 0x8CC70208, 0x90BEFFFA, 0xA4506CEB, 0xBEF9A3F7, 0xC67178F2,
+        };
+
         fn round(d: *Self, b: *const [64]u8) void {
-            var s: [64]u32 = undefined;
+            var s: [64]u32 align(16) = undefined;
 
             var i: usize = 0;
             while (i < 16) : (i += 1) {
@@ -179,6 +191,88 @@ fn Sha2x32(comptime params: Sha2Params32) type {
                 s[i] |= @as(u32, b[i * 4 + 2]) << 8;
                 s[i] |= @as(u32, b[i * 4 + 3]) << 0;
             }
+
+            if (builtin.cpu.arch == .aarch64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.aarch64.Feature.sha2))) {
+                var x: v4u32 = d.s[0..4].*;
+                var y: v4u32 = d.s[4..8].*;
+                const s_v = @ptrCast(*[16]v4u32, &s);
+
+                comptime var k: u8 = 0;
+                inline while (k < 16) : (k += 1) {
+                    if (k > 3) {
+                        s_v[k] = asm (
+                            \\sha256su0.4s %[w0_3], %[w4_7]
+                            \\sha256su1.4s %[w0_3], %[w8_11], %[w12_15]
+                            : [w0_3] "=w" (-> v4u32),
+                            : [_] "0" (s_v[k - 4]),
+                              [w4_7] "w" (s_v[k - 3]),
+                              [w8_11] "w" (s_v[k - 2]),
+                              [w12_15] "w" (s_v[k - 1]),
+                        );
+                    }
+
+                    const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
+                    asm volatile (
+                        \\mov.4s v0, %[x]
+                        \\sha256h.4s %[x], %[y], %[w]
+                        \\sha256h2.4s %[y], v0, %[w]
+                        : [x] "=w" (x),
+                          [y] "=w" (y),
+                        : [_] "0" (x),
+                          [_] "1" (y),
+                          [w] "w" (w),
+                        : "v0"
+                    );
+                }
+
+                d.s[0..4].* = x +% @as(v4u32, d.s[0..4].*);
+                d.s[4..8].* = y +% @as(v4u32, d.s[4..8].*);
+                return;
+            } else if (builtin.cpu.arch == .x86_64 and builtin.cpu.features.isEnabled(@enumToInt(std.Target.x86.Feature.sha))) {
+                var x: v4u32 = [_]u32{ d.s[5], d.s[4], d.s[1], d.s[0] };
+                var y: v4u32 = [_]u32{ d.s[7], d.s[6], d.s[3], d.s[2] };
+                const s_v = @ptrCast(*[16]v4u32, &s);
+
+                comptime var k: u8 = 0;
+                inline while (k < 16) : (k += 1) {
+                    if (k < 12) {
+                        const r = asm ("sha256msg1 %[w4_7], %[w0_3]"
+                            : [w0_3] "=x" (-> v4u32),
+                            : [_] "0" (s_v[k]),
+                              [w4_7] "x" (s_v[k + 1]),
+                        );
+                        const t = @shuffle(u32, s_v[k + 2], s_v[k + 3], [_]i32{ 1, 2, 3, -1 });
+                        s_v[k + 4] = asm ("sha256msg2 %[w12_15], %[t]"
+                            : [t] "=x" (-> v4u32),
+                            : [_] "0" (r +% t),
+                              [w12_15] "x" (s_v[k + 3]),
+                        );
+                    }
+
+                    const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
+                    asm volatile (
+                        \\sha256rnds2 %[x], %[y]
+                        \\pshufd $0xe, %%xmm0, %%xmm0
+                        \\sha256rnds2 %[y], %[x]
+                        : [y] "=x" (y),
+                          [x] "=x" (x),
+                        : [_] "0" (y),
+                          [_] "1" (x),
+                          [_] "{xmm0}" (w),
+                    );
+                }
+
+                d.s[0] +%= x[3];
+                d.s[1] +%= x[2];
+                d.s[4] +%= x[1];
+                d.s[5] +%= x[0];
+                d.s[2] +%= y[3];
+                d.s[3] +%= y[2];
+                d.s[6] +%= y[1];
+                d.s[7] +%= y[0];
+                return;
+            }
+
             while (i < 64) : (i += 1) {
                 s[i] = s[i - 16] +% s[i - 7] +% (math.rotr(u32, s[i - 15], @as(u32, 7)) ^ math.rotr(u32, s[i - 15], @as(u32, 18)) ^ (s[i - 15] >> 3)) +% (math.rotr(u32, s[i - 2], @as(u32, 17)) ^ math.rotr(u32, s[i - 2], @as(u32, 19)) ^ (s[i - 2] >> 10));
             }
@@ -195,73 +289,73 @@ fn Sha2x32(comptime params: Sha2Params32) type {
             };
 
             const round0 = comptime [_]RoundParam256{
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 0, 0x428A2F98),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 1, 0x71374491),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 2, 0xB5C0FBCF),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 3, 0xE9B5DBA5),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 4, 0x3956C25B),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 5, 0x59F111F1),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 6, 0x923F82A4),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 7, 0xAB1C5ED5),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 8, 0xD807AA98),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 9, 0x12835B01),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 10, 0x243185BE),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 11, 0x550C7DC3),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 12, 0x72BE5D74),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 13, 0x80DEB1FE),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 14, 0x9BDC06A7),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 15, 0xC19BF174),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 16, 0xE49B69C1),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 17, 0xEFBE4786),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 18, 0x0FC19DC6),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 19, 0x240CA1CC),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 20, 0x2DE92C6F),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 21, 0x4A7484AA),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 22, 0x5CB0A9DC),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 23, 0x76F988DA),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 24, 0x983E5152),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 25, 0xA831C66D),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 26, 0xB00327C8),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 27, 0xBF597FC7),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 28, 0xC6E00BF3),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 29, 0xD5A79147),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 30, 0x06CA6351),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 31, 0x14292967),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 32, 0x27B70A85),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 33, 0x2E1B2138),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 34, 0x4D2C6DFC),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 35, 0x53380D13),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 36, 0x650A7354),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 37, 0x766A0ABB),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 38, 0x81C2C92E),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 39, 0x92722C85),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 40, 0xA2BFE8A1),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 41, 0xA81A664B),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 42, 0xC24B8B70),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 43, 0xC76C51A3),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 44, 0xD192E819),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 45, 0xD6990624),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 46, 0xF40E3585),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 47, 0x106AA070),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 48, 0x19A4C116),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 49, 0x1E376C08),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 50, 0x2748774C),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 51, 0x34B0BCB5),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 52, 0x391C0CB3),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 53, 0x4ED8AA4A),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 54, 0x5B9CCA4F),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 55, 0x682E6FF3),
-                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 56, 0x748F82EE),
-                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 57, 0x78A5636F),
-                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 58, 0x84C87814),
-                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 59, 0x8CC70208),
-                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 60, 0x90BEFFFA),
-                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 61, 0xA4506CEB),
-                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 62, 0xBEF9A3F7),
-                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 63, 0xC67178F2),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 0),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 1),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 2),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 3),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 4),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 5),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 6),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 7),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 8),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 9),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 10),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 11),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 12),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 13),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 14),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 15),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 16),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 17),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 18),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 19),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 20),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 21),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 22),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 23),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 24),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 25),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 26),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 27),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 28),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 29),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 30),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 31),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 32),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 33),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 34),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 35),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 36),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 37),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 38),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 39),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 40),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 41),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 42),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 43),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 44),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 45),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 46),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 47),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 48),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 49),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 50),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 51),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 52),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 53),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 54),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 55),
+                roundParam256(0, 1, 2, 3, 4, 5, 6, 7, 56),
+                roundParam256(7, 0, 1, 2, 3, 4, 5, 6, 57),
+                roundParam256(6, 7, 0, 1, 2, 3, 4, 5, 58),
+                roundParam256(5, 6, 7, 0, 1, 2, 3, 4, 59),
+                roundParam256(4, 5, 6, 7, 0, 1, 2, 3, 60),
+                roundParam256(3, 4, 5, 6, 7, 0, 1, 2, 61),
+                roundParam256(2, 3, 4, 5, 6, 7, 0, 1, 62),
+                roundParam256(1, 2, 3, 4, 5, 6, 7, 0, 63),
             };
             inline for (round0) |r| {
-                v[r.h] = v[r.h] +% (math.rotr(u32, v[r.e], @as(u32, 6)) ^ math.rotr(u32, v[r.e], @as(u32, 11)) ^ math.rotr(u32, v[r.e], @as(u32, 25))) +% (v[r.g] ^ (v[r.e] & (v[r.f] ^ v[r.g]))) +% r.k +% s[r.i];
+                v[r.h] = v[r.h] +% (math.rotr(u32, v[r.e], @as(u32, 6)) ^ math.rotr(u32, v[r.e], @as(u32, 11)) ^ math.rotr(u32, v[r.e], @as(u32, 25))) +% (v[r.g] ^ (v[r.e] & (v[r.f] ^ v[r.g]))) +% W[r.i] +% s[r.i];
 
                 v[r.d] = v[r.d] +% v[r.h];