Commit 4c1f71e866

Cody Tapscott <topolarity@tapscott.me>
2022-10-24 18:47:31
std.crypto: Optimize SHA-256 intrinsics for AMD x86-64
This gets us most of the way back to the performance I had when I was using the LLVM intrinsics: - Intel Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz: 190.67 MB/s (w/o intrinsics) -> 1285.08 MB/s - AMD EPYC 7763 (VM) @ 2.45 GHz: 240.09 MB/s (w/o intrinsics) -> 1360.78 MB/s - Apple M1: 216.96 MB/s (w/o intrinsics) -> 2133.69 MB/s Minor changes to this source can swing performance from 400 MB/s to 1400 MB/s or... 20 MB/s, depending on how it interacts with the optimizer. I have a sneaking suspicion that despite LLVM inheriting GCC's extremely strict inline assembly semantics, its passes are rather skittish around inline assembly (and almost certainly, its instruction cost models can assume nothing)
1 parent ee241c4
Changed files (1)
lib
std
crypto
lib/std/crypto/sha2.zig
@@ -182,14 +182,8 @@ fn Sha2x32(comptime params: Sha2Params32) type {
 
         fn round(d: *Self, b: *const [64]u8) void {
             var s: [64]u32 align(16) = undefined;
-
-            var i: usize = 0;
-            while (i < 16) : (i += 1) {
-                s[i] = 0;
-                s[i] |= @as(u32, b[i * 4 + 0]) << 24;
-                s[i] |= @as(u32, b[i * 4 + 1]) << 16;
-                s[i] |= @as(u32, b[i * 4 + 2]) << 8;
-                s[i] |= @as(u32, b[i * 4 + 3]) << 0;
+            for (@ptrCast(*align(1) const [16]u32, b)) |*elem, i| {
+                s[i] = mem.readIntBig(u32, mem.asBytes(elem));
             }
 
             switch (builtin.cpu.arch) {
@@ -238,30 +232,35 @@ fn Sha2x32(comptime params: Sha2Params32) type {
                     comptime var k: u8 = 0;
                     inline while (k < 16) : (k += 1) {
                         if (k < 12) {
-                            const r = asm ("sha256msg1 %[w4_7], %[w0_3]"
-                                : [w0_3] "=x" (-> v4u32),
-                                : [_] "0" (s_v[k]),
+                            var tmp = s_v[k];
+                            s_v[k + 4] = asm (
+                                \\ sha256msg1 %[w4_7], %[tmp]
+                                \\ vpalignr $0x4, %[w8_11], %[w12_15], %[result]
+                                \\ paddd %[tmp], %[result]
+                                \\ sha256msg2 %[w12_15], %[result]
+                                : [tmp] "=&x" (tmp),
+                                  [result] "=&x" (-> v4u32),
+                                : [_] "0" (tmp),
                                   [w4_7] "x" (s_v[k + 1]),
-                            );
-                            const t = @shuffle(u32, s_v[k + 2], s_v[k + 3], [_]i32{ 1, 2, 3, -1 });
-                            s_v[k + 4] = asm ("sha256msg2 %[w12_15], %[t]"
-                                : [t] "=x" (-> v4u32),
-                                : [_] "0" (r +% t),
+                                  [w8_11] "x" (s_v[k + 2]),
                                   [w12_15] "x" (s_v[k + 3]),
                             );
                         }
 
                         const w: v4u32 = s_v[k] +% @as(v4u32, W[4 * k ..][0..4].*);
-                        asm volatile (
-                            \\sha256rnds2 %[x], %[y]
-                            \\pshufd $0xe, %%xmm0, %%xmm0
-                            \\sha256rnds2 %[y], %[x]
-                            : [y] "=x" (y),
-                              [x] "=x" (x),
+                        y = asm ("sha256rnds2 %[x], %[y]"
+                            : [y] "=x" (-> v4u32),
                             : [_] "0" (y),
-                              [_] "1" (x),
+                              [x] "x" (x),
                               [_] "{xmm0}" (w),
                         );
+
+                        x = asm ("sha256rnds2 %[y], %[x]"
+                            : [x] "=x" (-> v4u32),
+                            : [_] "0" (x),
+                              [y] "x" (y),
+                              [_] "{xmm0}" (@bitCast(v4u32, @bitCast(u128, w) >> 64)),
+                        );
                     }
 
                     d.s[0] +%= x[3];
@@ -277,6 +276,7 @@ fn Sha2x32(comptime params: Sha2Params32) type {
                 else => {},
             }
 
+            var i: usize = 16;
             while (i < 64) : (i += 1) {
                 s[i] = s[i - 16] +% s[i - 7] +% (math.rotr(u32, s[i - 15], @as(u32, 7)) ^ math.rotr(u32, s[i - 15], @as(u32, 18)) ^ (s[i - 15] >> 3)) +% (math.rotr(u32, s[i - 2], @as(u32, 17)) ^ math.rotr(u32, s[i - 2], @as(u32, 19)) ^ (s[i - 2] >> 10));
             }