Commit 72d3f4b5dc

Andrew Kelley <andrew@ziglang.org>
2022-11-17 23:37:22
Revert "std.crypto.onetimeauth.ghash: faster GHASH on modern CPUs (#13566)"
This reverts commit 7cfeae1ce7aa9f1b3a219d032c43bc2e694ba63b which is causing std lib tests to fail on wasm32-wasi.
1 parent 88a0f3d
Changed files (1)
lib
std
crypto
lib/std/crypto/ghash.zig
@@ -18,19 +18,12 @@ pub const Ghash = struct {
     pub const mac_length = 16;
     pub const key_length = 16;
 
-    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
+    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 4;
+    const agg_2_treshold = 5;
     const agg_4_treshold = 22;
     const agg_8_treshold = 84;
     const agg_16_treshold = 328;
 
-    // Before the Haswell architecture, the carryless multiplication instruction was
-    // extremely slow. Even with 128-bit operands, using Karatsuba multiplication was
-    // thus faster than a schoolbook multiplication.
-    // This is no longer the case -- Modern CPUs, including ARM-based ones, have a fast
-    // carryless multiplication instruction; using 4 multiplications is now faster than
-    // 3 multiplications with extra shifts and additions.
-    const mul_algorithm = if (builtin.cpu.arch == .x86) .karatsuba else .schoolbook;
-
     hx: [pc_count]Precomp,
     acc: u128 = 0,
 
@@ -50,10 +43,10 @@ pub const Ghash = struct {
         var hx: [pc_count]Precomp = undefined;
         hx[0] = h;
         hx[1] = gcmReduce(clsq128(hx[0])); // h^2
+        hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
+        hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
 
         if (builtin.mode != .ReleaseSmall) {
-            hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
-            hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
             if (block_count >= agg_8_treshold) {
                 hx[4] = gcmReduce(clmul128(hx[3], h)); // h^5
                 hx[5] = gcmReduce(clsq128(hx[2])); // h^6 = h^3^2
@@ -76,71 +69,47 @@ pub const Ghash = struct {
         return Ghash.initForBlockCount(key, math.maxInt(usize));
     }
 
-    const Selector = enum { lo, hi, hi_lo };
+    const Selector = enum { lo, hi };
 
     // Carryless multiplication of two 64-bit integers for x86_64.
     inline fn clmulPclmul(x: u128, y: u128, comptime half: Selector) u128 {
-        switch (half) {
-            .hi => {
-                const product = asm (
-                    \\ vpclmulqdq $0x11, %[x], %[y], %[out]
-                    : [out] "=x" (-> @Vector(2, u64)),
-                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
-                      [y] "x" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
-            .lo => {
-                const product = asm (
-                    \\ vpclmulqdq $0x00, %[x], %[y], %[out]
-                    : [out] "=x" (-> @Vector(2, u64)),
-                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
-                      [y] "x" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
-            .hi_lo => {
-                const product = asm (
-                    \\ vpclmulqdq $0x10, %[x], %[y], %[out]
-                    : [out] "=x" (-> @Vector(2, u64)),
-                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
-                      [y] "x" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
+        if (half == .hi) {
+            const product = asm (
+                \\ vpclmulqdq $0x11, %[x], %[y], %[out]
+                : [out] "=x" (-> @Vector(2, u64)),
+                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        } else {
+            const product = asm (
+                \\ vpclmulqdq $0x00, %[x], %[y], %[out]
+                : [out] "=x" (-> @Vector(2, u64)),
+                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
         }
     }
 
     // Carryless multiplication of two 64-bit integers for ARM crypto.
     inline fn clmulPmull(x: u128, y: u128, comptime half: Selector) u128 {
-        switch (half) {
-            .hi => {
-                const product = asm (
-                    \\ pmull2 %[out].1q, %[x].2d, %[y].2d
-                    : [out] "=w" (-> @Vector(2, u64)),
-                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
-                      [y] "w" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
-            .lo => {
-                const product = asm (
-                    \\ pmull %[out].1q, %[x].1d, %[y].1d
-                    : [out] "=w" (-> @Vector(2, u64)),
-                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
-                      [y] "w" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
-            .hi_lo => {
-                const product = asm (
-                    \\ pmull %[out].1q, %[x].1d, %[y].1d
-                    : [out] "=w" (-> @Vector(2, u64)),
-                    : [x] "w" (@bitCast(@Vector(2, u64), x >> 64)),
-                      [y] "w" (@bitCast(@Vector(2, u64), y)),
-                );
-                return @bitCast(u128, product);
-            },
+        if (half == .hi) {
+            const product = asm (
+                \\ pmull2 %[out].1q, %[x].2d, %[y].2d
+                : [out] "=w" (-> @Vector(2, u64)),
+                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        } else {
+            const product = asm (
+                \\ pmull %[out].1q, %[x].1d, %[y].1d
+                : [out] "=w" (-> @Vector(2, u64)),
+                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
         }
     }
 
@@ -175,63 +144,38 @@ pub const Ghash = struct {
             (z3 & 0x88888888888888888888888888888888) ^ extra;
     }
 
-    const I256 = struct {
-        hi: u128,
-        lo: u128,
-        mid: u128,
-    };
-
-    inline fn xor256(x: *I256, y: I256) void {
-        x.* = I256{
-            .hi = x.hi ^ y.hi,
-            .lo = x.lo ^ y.lo,
-            .mid = x.mid ^ y.mid,
-        };
-    }
-
     // Square a 128-bit integer in GF(2^128).
-    fn clsq128(x: u128) I256 {
-        return .{
-            .hi = clmul(x, x, .hi),
-            .lo = clmul(x, x, .lo),
-            .mid = 0,
-        };
+    fn clsq128(x: u128) u256 {
+        const lo = @truncate(u64, x);
+        const hi = @truncate(u64, x >> 64);
+        const mid = lo ^ hi;
+        const r_lo = clmul(x, x, .lo);
+        const r_hi = clmul(x, x, .hi);
+        const r_mid = clmul(mid, mid, .lo) ^ r_lo ^ r_hi;
+        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
     }
 
     // Multiply two 128-bit integers in GF(2^128).
-    inline fn clmul128(x: u128, y: u128) I256 {
-        if (mul_algorithm == .karatsuba) {
-            const x_hi = @truncate(u64, x >> 64);
-            const y_hi = @truncate(u64, y >> 64);
-            const r_lo = clmul(x, y, .lo);
-            const r_hi = clmul(x, y, .hi);
-            const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
-            return .{
-                .hi = r_hi,
-                .lo = r_lo,
-                .mid = r_mid,
-            };
-        } else {
-            return .{
-                .hi = clmul(x, y, .hi),
-                .lo = clmul(x, y, .lo),
-                .mid = clmul(x, y, .hi_lo) ^ clmul(y, x, .hi_lo),
-            };
-        }
+    inline fn clmul128(x: u128, y: u128) u256 {
+        const x_hi = @truncate(u64, x >> 64);
+        const y_hi = @truncate(u64, y >> 64);
+        const r_lo = clmul(x, y, .lo);
+        const r_hi = clmul(x, y, .hi);
+        const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
+        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
     }
 
     // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
     // This is done *without reversing the bits*, using Shay Gueron's black magic demysticated here:
     // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
-    inline fn gcmReduce(x: I256) u128 {
-        const hi = x.hi ^ (x.mid >> 64);
-        const lo = x.lo ^ (x.mid << 64);
+    inline fn gcmReduce(x: u256) u128 {
         const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
+        const lo = @truncate(u128, x);
         const a = clmul(lo, p64, .lo);
         const b = ((lo << 64) | (lo >> 64)) ^ a;
         const c = clmul(b, p64, .lo);
         const d = ((b << 64) | (b >> 64)) ^ c;
-        return d ^ hi;
+        return d ^ @truncate(u128, x >> 128);
     }
 
     const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
@@ -258,7 +202,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[15 - 0]);
                 comptime var j = 1;
                 inline while (j < 16) : (j += 1) {
-                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]));
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]);
                 }
                 acc = gcmReduce(u);
             }
@@ -268,7 +212,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[7 - 0]);
                 comptime var j = 1;
                 inline while (j < 8) : (j += 1) {
-                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]));
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]);
                 }
                 acc = gcmReduce(u);
             }
@@ -278,25 +222,31 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[3 - 0]);
                 comptime var j = 1;
                 inline while (j < 4) : (j += 1) {
-                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]));
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]);
                 }
                 acc = gcmReduce(u);
             }
-        }
-        // 2-blocks aggregated reduction
-        while (i + 32 <= msg.len) : (i += 32) {
-            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
-            comptime var j = 1;
-            inline while (j < 2) : (j += 1) {
-                xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]));
+        } else if (msg.len >= agg_2_treshold * block_length) {
+            // 2-blocks aggregated reduction
+            while (i + 32 <= msg.len) : (i += 32) {
+                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
+                comptime var j = 1;
+                inline while (j < 2) : (j += 1) {
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]);
+                }
+                acc = gcmReduce(u);
             }
-            acc = gcmReduce(u);
         }
         // remaining blocks
         if (i < msg.len) {
-            const u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[0]);
+            const n = (msg.len - i) / 16;
+            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[n - 1 - 0]);
+            var j: usize = 1;
+            while (j < n) : (j += 1) {
+                u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[n - 1 - j]);
+            }
+            i += n * 16;
             acc = gcmReduce(u);
-            i += 16;
         }
         assert(i == msg.len);
         st.acc = acc;