Commit 7cfeae1ce7

Frank Denis <124872+jedisct1@users.noreply.github.com>
2022-11-17 13:07:07
std.crypto.onetimeauth.ghash: faster GHASH on modern CPUs (#13566)
* std.crypto.onetimeauth.ghash: faster GHASH on modern CPUs Carryless multiplication was slow on older Intel CPUs, justifying the need for using Karatsuba multiplication. This is not the case any more; using 4 multiplications to multiply two 128-bit numbers is actually faster than 3 multiplications + shifts and additions. This is also true on aarch64. Keep using Karatsuba only when targeting x86 (granted, this is a bit of a brutal shortcut, we should really list all the CPU models that had a slow clmul instruction). Also remove useless agg_2 treshold and restore the ability to precompute only H and H^2 in ReleaseSmall. Finally, avoid using u256. Using 128-bit registers is actually faster. * Use a switch, add some comments
1 parent 58d9004
Changed files (1)
lib
std
crypto
lib/std/crypto/ghash.zig
@@ -18,12 +18,19 @@ pub const Ghash = struct {
     pub const mac_length = 16;
     pub const key_length = 16;
 
-    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 4;
-    const agg_2_treshold = 5;
+    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
     const agg_4_treshold = 22;
     const agg_8_treshold = 84;
     const agg_16_treshold = 328;
 
+    // Before the Haswell architecture, the carryless multiplication instruction was
+    // extremely slow. Even with 128-bit operands, using Karatsuba multiplication was
+    // thus faster than a schoolbook multiplication.
+    // This is no longer the case -- Modern CPUs, including ARM-based ones, have a fast
+    // carryless multiplication instruction; using 4 multiplications is now faster than
+    // 3 multiplications with extra shifts and additions.
+    const mul_algorithm = if (builtin.cpu.arch == .x86) .karatsuba else .schoolbook;
+
     hx: [pc_count]Precomp,
     acc: u128 = 0,
 
@@ -43,10 +50,10 @@ pub const Ghash = struct {
         var hx: [pc_count]Precomp = undefined;
         hx[0] = h;
         hx[1] = gcmReduce(clsq128(hx[0])); // h^2
-        hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
-        hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
 
         if (builtin.mode != .ReleaseSmall) {
+            hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
+            hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
             if (block_count >= agg_8_treshold) {
                 hx[4] = gcmReduce(clmul128(hx[3], h)); // h^5
                 hx[5] = gcmReduce(clsq128(hx[2])); // h^6 = h^3^2
@@ -69,47 +76,71 @@ pub const Ghash = struct {
         return Ghash.initForBlockCount(key, math.maxInt(usize));
     }
 
-    const Selector = enum { lo, hi };
+    const Selector = enum { lo, hi, hi_lo };
 
     // Carryless multiplication of two 64-bit integers for x86_64.
     inline fn clmulPclmul(x: u128, y: u128, comptime half: Selector) u128 {
-        if (half == .hi) {
-            const product = asm (
-                \\ vpclmulqdq $0x11, %[x], %[y], %[out]
-                : [out] "=x" (-> @Vector(2, u64)),
-                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
-        } else {
-            const product = asm (
-                \\ vpclmulqdq $0x00, %[x], %[y], %[out]
-                : [out] "=x" (-> @Vector(2, u64)),
-                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
+        switch (half) {
+            .hi => {
+                const product = asm (
+                    \\ vpclmulqdq $0x11, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .lo => {
+                const product = asm (
+                    \\ vpclmulqdq $0x00, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .hi_lo => {
+                const product = asm (
+                    \\ vpclmulqdq $0x10, %[x], %[y], %[out]
+                    : [out] "=x" (-> @Vector(2, u64)),
+                    : [x] "x" (@bitCast(@Vector(2, u64), x)),
+                      [y] "x" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
         }
     }
 
     // Carryless multiplication of two 64-bit integers for ARM crypto.
     inline fn clmulPmull(x: u128, y: u128, comptime half: Selector) u128 {
-        if (half == .hi) {
-            const product = asm (
-                \\ pmull2 %[out].1q, %[x].2d, %[y].2d
-                : [out] "=w" (-> @Vector(2, u64)),
-                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
-        } else {
-            const product = asm (
-                \\ pmull %[out].1q, %[x].1d, %[y].1d
-                : [out] "=w" (-> @Vector(2, u64)),
-                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
-                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
-            );
-            return @bitCast(u128, product);
+        switch (half) {
+            .hi => {
+                const product = asm (
+                    \\ pmull2 %[out].1q, %[x].2d, %[y].2d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .lo => {
+                const product = asm (
+                    \\ pmull %[out].1q, %[x].1d, %[y].1d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
+            .hi_lo => {
+                const product = asm (
+                    \\ pmull %[out].1q, %[x].1d, %[y].1d
+                    : [out] "=w" (-> @Vector(2, u64)),
+                    : [x] "w" (@bitCast(@Vector(2, u64), x >> 64)),
+                      [y] "w" (@bitCast(@Vector(2, u64), y)),
+                );
+                return @bitCast(u128, product);
+            },
         }
     }
 
@@ -144,38 +175,63 @@ pub const Ghash = struct {
             (z3 & 0x88888888888888888888888888888888) ^ extra;
     }
 
+    const I256 = struct {
+        hi: u128,
+        lo: u128,
+        mid: u128,
+    };
+
+    inline fn xor256(x: *I256, y: I256) void {
+        x.* = I256{
+            .hi = x.hi ^ y.hi,
+            .lo = x.lo ^ y.lo,
+            .mid = x.mid ^ y.mid,
+        };
+    }
+
     // Square a 128-bit integer in GF(2^128).
-    fn clsq128(x: u128) u256 {
-        const lo = @truncate(u64, x);
-        const hi = @truncate(u64, x >> 64);
-        const mid = lo ^ hi;
-        const r_lo = clmul(x, x, .lo);
-        const r_hi = clmul(x, x, .hi);
-        const r_mid = clmul(mid, mid, .lo) ^ r_lo ^ r_hi;
-        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    fn clsq128(x: u128) I256 {
+        return .{
+            .hi = clmul(x, x, .hi),
+            .lo = clmul(x, x, .lo),
+            .mid = 0,
+        };
     }
 
     // Multiply two 128-bit integers in GF(2^128).
-    inline fn clmul128(x: u128, y: u128) u256 {
-        const x_hi = @truncate(u64, x >> 64);
-        const y_hi = @truncate(u64, y >> 64);
-        const r_lo = clmul(x, y, .lo);
-        const r_hi = clmul(x, y, .hi);
-        const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
-        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    inline fn clmul128(x: u128, y: u128) I256 {
+        if (mul_algorithm == .karatsuba) {
+            const x_hi = @truncate(u64, x >> 64);
+            const y_hi = @truncate(u64, y >> 64);
+            const r_lo = clmul(x, y, .lo);
+            const r_hi = clmul(x, y, .hi);
+            const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
+            return .{
+                .hi = r_hi,
+                .lo = r_lo,
+                .mid = r_mid,
+            };
+        } else {
+            return .{
+                .hi = clmul(x, y, .hi),
+                .lo = clmul(x, y, .lo),
+                .mid = clmul(x, y, .hi_lo) ^ clmul(y, x, .hi_lo),
+            };
+        }
     }
 
     // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
     // This is done *without reversing the bits*, using Shay Gueron's black magic demysticated here:
     // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
-    inline fn gcmReduce(x: u256) u128 {
+    inline fn gcmReduce(x: I256) u128 {
+        const hi = x.hi ^ (x.mid >> 64);
+        const lo = x.lo ^ (x.mid << 64);
         const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
-        const lo = @truncate(u128, x);
         const a = clmul(lo, p64, .lo);
         const b = ((lo << 64) | (lo >> 64)) ^ a;
         const c = clmul(b, p64, .lo);
         const d = ((b << 64) | (b >> 64)) ^ c;
-        return d ^ @truncate(u128, x >> 128);
+        return d ^ hi;
     }
 
     const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
@@ -202,7 +258,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[15 - 0]);
                 comptime var j = 1;
                 inline while (j < 16) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]));
                 }
                 acc = gcmReduce(u);
             }
@@ -212,7 +268,7 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[7 - 0]);
                 comptime var j = 1;
                 inline while (j < 8) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]));
                 }
                 acc = gcmReduce(u);
             }
@@ -222,31 +278,25 @@ pub const Ghash = struct {
                 var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[3 - 0]);
                 comptime var j = 1;
                 inline while (j < 4) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]);
+                    xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]));
                 }
                 acc = gcmReduce(u);
             }
-        } else if (msg.len >= agg_2_treshold * block_length) {
-            // 2-blocks aggregated reduction
-            while (i + 32 <= msg.len) : (i += 32) {
-                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
-                comptime var j = 1;
-                inline while (j < 2) : (j += 1) {
-                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]);
-                }
-                acc = gcmReduce(u);
+        }
+        // 2-blocks aggregated reduction
+        while (i + 32 <= msg.len) : (i += 32) {
+            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
+            comptime var j = 1;
+            inline while (j < 2) : (j += 1) {
+                xor256(&u, clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]));
             }
+            acc = gcmReduce(u);
         }
         // remaining blocks
         if (i < msg.len) {
-            const n = (msg.len - i) / 16;
-            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[n - 1 - 0]);
-            var j: usize = 1;
-            while (j < n) : (j += 1) {
-                u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[n - 1 - j]);
-            }
-            i += n * 16;
+            const u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[0]);
             acc = gcmReduce(u);
+            i += 16;
         }
         assert(i == msg.len);
         st.acc = acc;