Commit 0d192ee9ef

Frank Denis <124872+jedisct1@users.noreply.github.com>
2022-11-01 18:49:13
std.crypto.onetimeauth.Ghash: make GHASH 2 - 2.5x faster (#13374)
Rewrite GHASH to use 128-bit multiplication over non-reversed integers, and up to 8-block aggregated reduction.

lib/std/crypto/benchmark.zig results:

Xeon E5:   Before: 1604 MiB/s   After: 4005 MiB/s
Apple M1:  Before: 2769 MiB/s   After: 6014 MiB/s

This also makes AES-GCM faster, by the way.
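GHASH evaluates a polynomial over GF(2^128) at the point H, so the per-block Horner step acc = (acc ^ b) * H (mod P) can be aggregated: several carry-less products are XORed together and a single modular reduction is shared between them. A minimal sketch of the 4-block case, mirroring the loop this commit adds to ghash.zig (clmul128, gcm_reduce, and the precomputed powers hx, where hx[k] holds H^(k+1)):

    // One reduction per four blocks instead of four reductions:
    //   acc = ((acc ^ b0)*H^4 ^ b1*H^3 ^ b2*H^2 ^ b3*H) (mod P)
    const u = clmul128(acc ^ b0, hx[3]) ^ clmul128(b1, hx[2]) ^
        clmul128(b2, hx[1]) ^ clmul128(b3, hx[0]);
    acc = gcm_reduce(u);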
1 parent 1780d7a
Changed files (2)
lib/std/crypto/aes_gcm.zig
@@ -3,6 +3,7 @@ const assert = std.debug.assert;
 const crypto = std.crypto;
 const debug = std.debug;
 const Ghash = std.crypto.onetimeauth.Ghash;
+const math = std.math;
 const mem = std.mem;
 const modes = crypto.core.modes;
 const AuthenticationError = crypto.errors.AuthenticationError;
@@ -34,7 +35,8 @@ fn AesGcm(comptime Aes: anytype) type {
             mem.writeIntBig(u32, j[nonce_length..][0..4], 1);
             aes.encrypt(&t, &j);
 
-            var mac = Ghash.init(&h);
+            const block_count = (math.divCeil(usize, ad.len, Ghash.block_length) catch unreachable) + (math.divCeil(usize, c.len, Ghash.block_length) catch unreachable);
+            var mac = Ghash.initForBlockCount(&h, block_count);
             mac.update(ad);
             mac.pad();
 
@@ -66,7 +68,8 @@ fn AesGcm(comptime Aes: anytype) type {
             mem.writeIntBig(u32, j[nonce_length..][0..4], 1);
             aes.encrypt(&t, &j);
 
-            var mac = Ghash.init(&h);
+            const block_count = (math.divCeil(usize, ad.len, Ghash.block_length) catch unreachable) + (math.divCeil(usize, c.len, Ghash.block_length) catch unreachable) + 1;
+            var mac = Ghash.initForBlockCount(&h, block_count);
             mac.update(ad);
             mac.pad();
 
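For reference, a minimal sketch (not part of the commit) of how the block count fed to Ghash.initForBlockCount is derived; it mirrors the decryption path above, while the encryption path omits the trailing +1. math.divCeil can only fail on a zero divisor (or signed overflow), and Ghash.block_length is a nonzero constant, which is why catch unreachable is safe:

    const std = @import("std");
    const Ghash = std.crypto.onetimeauth.Ghash;

    // Hypothetical helper: one GHASH block per 16 bytes of associated data
    // and ciphertext, plus one trailing block for the encoded lengths.
    fn ghashBlockCount(ad_len: usize, c_len: usize) usize {
        const ad_blocks = std.math.divCeil(usize, ad_len, Ghash.block_length) catch unreachable;
        const c_blocks = std.math.divCeil(usize, c_len, Ghash.block_length) catch unreachable;
        return ad_blocks + c_blocks + 1;
    }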
lib/std/crypto/ghash.zig
@@ -1,6 +1,3 @@
-//
-// Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin <pornin@bolet.org>
-
 const std = @import("../std.zig");
 const builtin = @import("builtin");
 const assert = std.debug.assert;
@@ -8,6 +5,8 @@ const math = std.math;
 const mem = std.mem;
 const utils = std.crypto.utils;
 
+const Precomp = u128;
+
 /// GHASH is a universal hash function that features multiplication
 /// by a fixed parameter within a Galois field.
 ///
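As a usage reference, a minimal sketch (not part of the commit) showing that the one-shot and streaming interfaces of this type produce the same tag; the key and message are arbitrary placeholders:

    const std = @import("std");
    const Ghash = std.crypto.onetimeauth.Ghash;

    test "ghash one-shot and streaming agree" {
        const key = [_]u8{0x42} ** Ghash.key_length;
        const msg = "the quick brown fox jumps over the lazy dog" ** 3;

        var tag1: [Ghash.mac_length]u8 = undefined;
        Ghash.create(&tag1, msg, &key);

        var st = Ghash.init(&key);
        st.update(msg);
        var tag2: [Ghash.mac_length]u8 = undefined;
        st.final(&tag2);

        try std.testing.expectEqualSlices(u8, &tag1, &tag2);
    }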
@@ -19,116 +18,132 @@ pub const Ghash = struct {
     pub const mac_length = 16;
     pub const key_length = 16;
 
-    y0: u64 = 0,
-    y1: u64 = 0,
-    h0: u64,
-    h1: u64,
-    h2: u64,
-    h0r: u64,
-    h1r: u64,
-    h2r: u64,
-
-    hh0: u64 = undefined,
-    hh1: u64 = undefined,
-    hh2: u64 = undefined,
-    hh0r: u64 = undefined,
-    hh1r: u64 = undefined,
-    hh2r: u64 = undefined,
+    const pc_count = if (builtin.mode != .ReleaseSmall) 8 else 1;
+
+    hx: [pc_count]Precomp,
+    acc: u128 = 0,
 
     leftover: usize = 0,
     buf: [block_length]u8 align(16) = undefined,
 
-    pub fn init(key: *const [key_length]u8) Ghash {
-        const h1 = mem.readIntBig(u64, key[0..8]);
-        const h0 = mem.readIntBig(u64, key[8..16]);
-        const h1r = @bitReverse(h1);
-        const h0r = @bitReverse(h0);
-        const h2 = h0 ^ h1;
-        const h2r = h0r ^ h1r;
-
-        if (builtin.mode == .ReleaseSmall) {
-            return Ghash{
-                .h0 = h0,
-                .h1 = h1,
-                .h2 = h2,
-                .h0r = h0r,
-                .h1r = h1r,
-                .h2r = h2r,
-            };
-        } else {
-            // Precompute H^2
-            var hh = Ghash{
-                .h0 = h0,
-                .h1 = h1,
-                .h2 = h2,
-                .h0r = h0r,
-                .h1r = h1r,
-                .h2r = h2r,
-            };
-            hh.update(key);
-            const hh1 = hh.y1;
-            const hh0 = hh.y0;
-            const hh1r = @bitReverse(hh1);
-            const hh0r = @bitReverse(hh0);
-            const hh2 = hh0 ^ hh1;
-            const hh2r = hh0r ^ hh1r;
-
-            return Ghash{
-                .h0 = h0,
-                .h1 = h1,
-                .h2 = h2,
-                .h0r = h0r,
-                .h1r = h1r,
-                .h2r = h2r,
-
-                .hh0 = hh0,
-                .hh1 = hh1,
-                .hh2 = hh2,
-                .hh0r = hh0r,
-                .hh1r = hh1r,
-                .hh2r = hh2r,
-            };
+    /// Initialize the GHASH state with a key, and the minimum number of blocks that will be processed.
+    pub fn initForBlockCount(key: *const [key_length]u8, block_count: usize) Ghash {
+        const h0 = mem.readIntBig(u128, key[0..16]);
+
+        // We keep the values encoded as in GCM, not Polyval, i.e. without reversing the bits.
+        // This is fine, but the carry-less product of two bit-reversed operands is the
+        // bit-reversed product, shifted right by one bit. So, we shift h to compensate.
+        const carry = ((@as(u128, 0xc2) << 120) | 1) & (@as(u128, 0) -% (h0 >> 127));
+        const h = (h0 << 1) ^ carry;
+
+        var hx: [pc_count]Precomp = undefined;
+        hx[0] = h;
+        if (builtin.mode != .ReleaseSmall) {
+            if (block_count > 2) {
+                hx[1] = gcm_reduce(clsq128(hx[0])); // h^2
+            }
+            if (block_count > 4) {
+                hx[2] = gcm_reduce(clmul128(hx[1], h)); // h^3
+                hx[3] = gcm_reduce(clsq128(hx[1])); // h^4
+            }
+            if (block_count > 8) {
+                hx[4] = gcm_reduce(clmul128(hx[3], h)); // h^5
+                hx[5] = gcm_reduce(clmul128(hx[4], h)); // h^6
+                hx[6] = gcm_reduce(clmul128(hx[5], h)); // h^7
+                hx[7] = gcm_reduce(clsq128(hx[3])); // h^8
+            }
         }
+        return Ghash{ .hx = hx };
     }
 
-    inline fn clmul_pclmul(x: u64, y: u64) u64 {
+    /// Initialize the GHASH state with a key.
+    pub fn init(key: *const [key_length]u8) Ghash {
+        return Ghash.initForBlockCount(key, math.maxInt(usize));
+    }
+
+    // Carryless multiplication of two 64-bit integers for x86_64.
+    inline fn clmul_pclmul(x: u64, y: u64) u128 {
         const product = asm (
             \\ vpclmulqdq $0x00, %[x], %[y], %[out]
             : [out] "=x" (-> @Vector(2, u64)),
             : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
               [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
         );
-        return product[0];
+        return (@as(u128, product[1]) << 64) | product[0];
     }
 
-    inline fn clmul_pmull(x: u64, y: u64) u64 {
+    // Carryless multiplication of two 64-bit integers for ARM crypto.
+    inline fn clmul_pmull(x: u64, y: u64) u128 {
         const product = asm (
             \\ pmull %[out].1q, %[x].1d, %[y].1d
             : [out] "=w" (-> @Vector(2, u64)),
             : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
               [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
         );
-        return product[0];
+        return (@as(u128, product[1]) << 64) | product[0];
     }
 
-    fn clmul_soft(x: u64, y: u64) u64 {
-        const x0 = x & 0x1111111111111111;
-        const x1 = x & 0x2222222222222222;
-        const x2 = x & 0x4444444444444444;
-        const x3 = x & 0x8888888888888888;
+    // Software carryless multiplication of two 64-bit integers.
+    fn clmul_soft(x: u64, y: u64) u128 {
+        const x0 = x & 0x1111111111111110;
+        const x1 = x & 0x2222222222222220;
+        const x2 = x & 0x4444444444444440;
+        const x3 = x & 0x8888888888888880;
         const y0 = y & 0x1111111111111111;
         const y1 = y & 0x2222222222222222;
         const y2 = y & 0x4444444444444444;
         const y3 = y & 0x8888888888888888;
-        var z0 = (x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1);
-        var z1 = (x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2);
-        var z2 = (x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3);
-        var z3 = (x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0);
-        z0 &= 0x1111111111111111;
-        z1 &= 0x2222222222222222;
-        z2 &= 0x4444444444444444;
-        z3 &= 0x8888888888888888;
-        return z0 | z1 | z2 | z3;
+        const z0 = (x0 * @as(u128, y0)) ^ (x1 * @as(u128, y3)) ^ (x2 * @as(u128, y2)) ^ (x3 * @as(u128, y1));
+        const z1 = (x0 * @as(u128, y1)) ^ (x1 * @as(u128, y0)) ^ (x2 * @as(u128, y3)) ^ (x3 * @as(u128, y2));
+        const z2 = (x0 * @as(u128, y2)) ^ (x1 * @as(u128, y1)) ^ (x2 * @as(u128, y0)) ^ (x3 * @as(u128, y3));
+        const z3 = (x0 * @as(u128, y3)) ^ (x1 * @as(u128, y2)) ^ (x2 * @as(u128, y1)) ^ (x3 * @as(u128, y0));
+
+        const x0_mask = @as(u64, 0) -% (x & 1);
+        const x1_mask = @as(u64, 0) -% ((x >> 1) & 1);
+        const x2_mask = @as(u64, 0) -% ((x >> 2) & 1);
+        const x3_mask = @as(u64, 0) -% ((x >> 3) & 1);
+        const extra = (x0_mask & y) ^ (@as(u128, x1_mask & y) << 1) ^
+            (@as(u128, x2_mask & y) << 2) ^ (@as(u128, x3_mask & y) << 3);
+
+        return (z0 & 0x11111111111111111111111111111111) ^
+            (z1 & 0x22222222222222222222222222222222) ^
+            (z2 & 0x44444444444444444444444444444444) ^
+            (z3 & 0x88888888888888888888888888888888) ^ extra;
+    }
+
+    // Square a 128-bit integer in GF(2^128).
+    fn clsq128(x: u128) u256 {
+        const lo = @truncate(u64, x);
+        const hi = @truncate(u64, x >> 64);
+        const mid = lo ^ hi;
+        const r_lo = clmul(lo, lo);
+        const r_hi = clmul(hi, hi);
+        const r_mid = clmul(mid, mid) ^ r_lo ^ r_hi;
+        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    }
+
+    // Multiply two 128-bit integers in GF(2^128).
+    inline fn clmul128(x: u128, y: u128) u256 {
+        const x_lo = @truncate(u64, x);
+        const x_hi = @truncate(u64, x >> 64);
+        const y_lo = @truncate(u64, y);
+        const y_hi = @truncate(u64, y >> 64);
+        const r_lo = clmul(x_lo, y_lo);
+        const r_hi = clmul(x_hi, y_hi);
+        const r_mid = clmul(x_lo ^ x_hi, y_lo ^ y_hi) ^ r_lo ^ r_hi;
+        return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
+    }
+
+    // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
+    // This is done *without reversing the bits*, using Shay Gueron's black magic demystified here:
+    // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
+    inline fn gcm_reduce(x: u256) u128 {
+        const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
+        const a = clmul(@truncate(u64, x), p64);
+        const b = ((@truncate(u128, x) << 64) | (@truncate(u128, x) >> 64)) ^ a;
+        const c = clmul(@truncate(u64, b), p64);
+        const d = ((b << 64) | (b >> 64)) ^ c;
+        return d ^ @truncate(u128, x >> 128);
     }
 
     const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul);
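Two notes on the helpers above, as a reading aid. clmul128 is a Karatsuba multiplication over GF(2)[x]: the middle term clmul(x_lo ^ x_hi, y_lo ^ y_hi) ^ r_lo ^ r_hi stands in for both cross products, so a 128x128 carry-less multiply costs three 64x64 clmul calls instead of four. And clmul_soft can be checked against a naive shift-and-xor reference; a minimal sketch (not part of the commit), with naiveClmul as a hypothetical test helper:

    // Schoolbook carry-less multiply: XOR in a shifted copy of x for every
    // set bit of y. Slow, but an easy-to-audit reference.
    fn naiveClmul(x: u64, y: u64) u128 {
        var z: u128 = 0;
        var i: u7 = 0;
        while (i < 64) : (i += 1) {
            if ((y >> @intCast(u6, i)) & 1 != 0) {
                z ^= @as(u128, x) << i;
            }
        }
        return z;
    }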
@@ -142,116 +157,100 @@ pub const Ghash = struct {
         break :impl clmul_soft;
     };
 
+    // Process a sequence of full 16-byte blocks.
     fn blocks(st: *Ghash, msg: []const u8) void {
         assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
-        var y1 = st.y1;
-        var y0 = st.y0;
+        var acc = st.acc;
 
         var i: usize = 0;
 
-        // 2-blocks aggregated reduction
         if (builtin.mode != .ReleaseSmall) {
+            // 8-block aggregated reduction
+            while (i + 128 <= msg.len) : (i += 128) {
+                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
+                const z0 = acc ^ b0;
+                const z0h = clmul128(z0, st.hx[7]);
+
+                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
+                const b1h = clmul128(b1, st.hx[6]);
+
+                const b2 = mem.readIntBig(u128, msg[i..][32..48]);
+                const b2h = clmul128(b2, st.hx[5]);
+
+                const b3 = mem.readIntBig(u128, msg[i..][48..64]);
+                const b3h = clmul128(b3, st.hx[4]);
+
+                const b4 = mem.readIntBig(u128, msg[i..][64..80]);
+                const b4h = clmul128(b4, st.hx[3]);
+
+                const b5 = mem.readIntBig(u128, msg[i..][80..96]);
+                const b5h = clmul128(b5, st.hx[2]);
+
+                const b6 = mem.readIntBig(u128, msg[i..][96..112]);
+                const b6h = clmul128(b6, st.hx[1]);
+
+                const b7 = mem.readIntBig(u128, msg[i..][112..128]);
+                const b7h = clmul128(b7, st.hx[0]);
+
+                const u = z0h ^ b1h ^ b2h ^ b3h ^ b4h ^ b5h ^ b6h ^ b7h;
+                acc = gcm_reduce(u);
+            }
+
+            // 4-block aggregated reduction
+            while (i + 64 <= msg.len) : (i += 64) {
+                // (acc + b0) * H^4 unreduced
+                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
+                const z0 = acc ^ b0;
+                const z0h = clmul128(z0, st.hx[3]);
+
+                // b1 * H^3 unreduced
+                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
+                const b1h = clmul128(b1, st.hx[2]);
+
+                // b2 * H^2 unreduced
+                const b2 = mem.readIntBig(u128, msg[i..][32..48]);
+                const b2h = clmul128(b2, st.hx[1]);
+
+                // b3 * H unreduced
+                const b3 = mem.readIntBig(u128, msg[i..][48..64]);
+                const b3h = clmul128(b3, st.hx[0]);
+
+                // (((acc + b0) * H^4) + B1 * H^3 + B2 * H^2 + B3 * H) (mod P)
+                const u = z0h ^ b1h ^ b2h ^ b3h;
+                acc = gcm_reduce(u);
+            }
+
+            // 2-block aggregated reduction
             while (i + 32 <= msg.len) : (i += 32) {
-                // B0 * H^2 unreduced
-                y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
-                y0 ^= mem.readIntBig(u64, msg[i..][8..16]);
-
-                const y1r = @bitReverse(y1);
-                const y0r = @bitReverse(y0);
-                const y2 = y0 ^ y1;
-                const y2r = y0r ^ y1r;
-
-                var z0 = clmul(y0, st.hh0);
-                var z1 = clmul(y1, st.hh1);
-                var z2 = clmul(y2, st.hh2) ^ z0 ^ z1;
-                var z0h = clmul(y0r, st.hh0r);
-                var z1h = clmul(y1r, st.hh1r);
-                var z2h = clmul(y2r, st.hh2r) ^ z0h ^ z1h;
-
-                // B1 * H unreduced
-                const sy1 = mem.readIntBig(u64, msg[i..][16..24]);
-                const sy0 = mem.readIntBig(u64, msg[i..][24..32]);
-
-                const sy1r = @bitReverse(sy1);
-                const sy0r = @bitReverse(sy0);
-                const sy2 = sy0 ^ sy1;
-                const sy2r = sy0r ^ sy1r;
-
-                const sz0 = clmul(sy0, st.h0);
-                const sz1 = clmul(sy1, st.h1);
-                const sz2 = clmul(sy2, st.h2) ^ sz0 ^ sz1;
-                const sz0h = clmul(sy0r, st.h0r);
-                const sz1h = clmul(sy1r, st.h1r);
-                const sz2h = clmul(sy2r, st.h2r) ^ sz0h ^ sz1h;
-
-                // ((B0 * H^2) + B1 * H) (mod M)
-                z0 ^= sz0;
-                z1 ^= sz1;
-                z2 ^= sz2;
-                z0h ^= sz0h;
-                z1h ^= sz1h;
-                z2h ^= sz2h;
-                z0h = @bitReverse(z0h) >> 1;
-                z1h = @bitReverse(z1h) >> 1;
-                z2h = @bitReverse(z2h) >> 1;
-
-                var v3 = z1h;
-                var v2 = z1 ^ z2h;
-                var v1 = z0h ^ z2;
-                var v0 = z0;
-
-                v3 = (v3 << 1) | (v2 >> 63);
-                v2 = (v2 << 1) | (v1 >> 63);
-                v1 = (v1 << 1) | (v0 >> 63);
-                v0 = (v0 << 1);
-
-                v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
-                v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
-                y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
-                y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
+                // (acc + b0) * H^2 unreduced
+                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
+                const z0 = acc ^ b0;
+                const z0h = clmul128(z0, st.hx[1]);
+
+                // b1 * H unreduced
+                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
+                const b1h = clmul128(b1, st.hx[0]);
+
+                // (((acc + b0) * H^2) + B1 * H) (mod P)
+                const u = z0h ^ b1h;
+                acc = gcm_reduce(u);
             }
         }
 
         // single block
         while (i + 16 <= msg.len) : (i += 16) {
-            y1 ^= mem.readIntBig(u64, msg[i..][0..8]);
-            y0 ^= mem.readIntBig(u64, msg[i..][8..16]);
-
-            const y1r = @bitReverse(y1);
-            const y0r = @bitReverse(y0);
-            const y2 = y0 ^ y1;
-            const y2r = y0r ^ y1r;
-
-            const z0 = clmul(y0, st.h0);
-            const z1 = clmul(y1, st.h1);
-            var z2 = clmul(y2, st.h2) ^ z0 ^ z1;
-            var z0h = clmul(y0r, st.h0r);
-            var z1h = clmul(y1r, st.h1r);
-            var z2h = clmul(y2r, st.h2r) ^ z0h ^ z1h;
-            z0h = @bitReverse(z0h) >> 1;
-            z1h = @bitReverse(z1h) >> 1;
-            z2h = @bitReverse(z2h) >> 1;
-
-            // shift & reduce
-            var v3 = z1h;
-            var v2 = z1 ^ z2h;
-            var v1 = z0h ^ z2;
-            var v0 = z0;
-
-            v3 = (v3 << 1) | (v2 >> 63);
-            v2 = (v2 << 1) | (v1 >> 63);
-            v1 = (v1 << 1) | (v0 >> 63);
-            v0 = (v0 << 1);
-
-            v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7);
-            v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57);
-            y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7);
-            y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57);
+            // (acc + b0) * H unreduced
+            const b0 = mem.readIntBig(u128, msg[i..][0..16]);
+            const z0 = acc ^ b0;
+            const z0h = clmul128(z0, st.hx[0]);
+
+            // (acc + b0) * H (mod P)
+            acc = gcm_reduce(z0h);
         }
-        st.y1 = y1;
-        st.y0 = y0;
+        st.acc = acc;
     }
 
+    /// Absorb a message into the GHASH state.
     pub fn update(st: *Ghash, m: []const u8) void {
         var mb = m;
 
@@ -295,14 +294,15 @@ pub const Ghash = struct {
         st.leftover = 0;
     }
 
+    /// Compute the GHASH of the entire input.
     pub fn final(st: *Ghash, out: *[mac_length]u8) void {
         st.pad();
-        mem.writeIntBig(u64, out[0..8], st.y1);
-        mem.writeIntBig(u64, out[8..16], st.y0);
+        mem.writeIntBig(u128, out[0..16], st.acc);
 
         utils.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]);
     }
 
+    /// Compute the GHASH of a message.
     pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {
         var st = Ghash.init(key);
         st.update(msg);