Commit 7d48cb1138

Frank Denis <124872+jedisct1@users.noreply.github.com>
2022-11-07 21:45:29
std.crypto: make ghash faster, esp. for small messages (#13464)
* std.crypto: make ghash faster, esp. for small messages

Aggregated reduction requires 5 additional multiplications (to precompute the powers of H) in order to save 2 multiplications per batch. So, only use large batches when they actually pay off.

For the last blocks, reuse the precomputations in order to perform a single reduction.

Also, even in .ReleaseSmall, allow 2-block aggregation: the speedup is worth it, and the code-size increase is reasonable. And in .ReleaseFast, bump the upper batch size up to 16. Leverage comptime instead of duplicating code.

std/crypto/benchmark.zig on Apple M1:

Zig 0.10.0: 2769 MiB/s
Before:     6014 MiB/s
After:      7334 MiB/s

Also normalize the function names.

* Change clmul() to accept the half to be processed

This avoids a bunch of truncate() calls.

* Add more ghash tests to check all code paths
1 parent 88d2e4f
Changed files (1)
lib/std/crypto/ghash.zig
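
The change below hinges on deferred ("aggregated") reduction: since reduction modulo the GHASH field polynomial is a linear map, the per-block carryless products can be XORed together unreduced and reduced once per batch, at the one-time cost of precomputing powers of H. Per the commit message, that setup costs about 5 extra multiplications and saves about 2 per batch, so 2-block aggregation only starts winning after a handful of blocks, which lines up with agg_2_treshold being 5 blocks. Below is a minimal sketch of that equivalence, assuming plain (non bit-reflected) polynomial arithmetic with the standard-order polynomial x^128 + x^7 + x^2 + x + 1; clmulSoft128, reduce and the test values are illustrative stand-ins, not the routines or constants from ghash.zig.

const std = @import("std");

// Illustration only: GF(2^128) in the plain polynomial basis, reduced modulo
// x^128 + x^7 + x^2 + x + 1 -- NOT the bit-reflected representation and
// reduction constant used by the actual ghash.zig code below.

// Carryless (polynomial) multiplication of two 128-bit values.
fn clmulSoft128(x: u128, y: u128) u256 {
    var r: u256 = 0;
    var a: u256 = x;
    var b: u128 = y;
    while (b != 0) : (b >>= 1) {
        if (b & 1 != 0) r ^= a;
        a <<= 1;
    }
    return r;
}

// Reduce a polynomial of degree < 256 modulo x^128 + x^7 + x^2 + x + 1.
fn reduce(x_: u256) u128 {
    const poly: u256 = (1 << 7) | (1 << 2) | (1 << 1) | 1;
    var x = x_;
    var i: u8 = 255;
    while (i >= 128) : (i -= 1) {
        if ((x >> i) & 1 != 0) {
            x ^= (@as(u256, 1) << i) ^ (poly << (i - 128));
        }
    }
    return @truncate(u128, x);
}

test "aggregated reduction matches per-block reduction" {
    const h: u128 = 0x0123456789abcdef0123456789abcdef;
    const blocks = [_]u128{ 0x01, 0x7f7f7f7f, 0xdeadbeef, 0x42424242 };

    // Reference: acc = (acc ^ b) * H (mod P), reduced after every block.
    var acc: u128 = 0;
    for (blocks) |b| {
        acc = reduce(clmulSoft128(acc ^ b, h));
    }

    // Aggregated: precompute H^1..H^4 once, XOR the unreduced 256-bit
    // products, then reduce a single time for the whole batch.
    var hx: [4]u128 = undefined;
    hx[0] = h;
    var i: usize = 1;
    while (i < hx.len) : (i += 1) {
        hx[i] = reduce(clmulSoft128(hx[i - 1], h)); // H^(i+1)
    }
    // The running accumulator (zero here) is folded into the first block,
    // as blocks() does with `acc ^ mem.readIntBig(...)`.
    var u = clmulSoft128(blocks[0], hx[3]);
    var j: usize = 1;
    while (j < blocks.len) : (j += 1) {
        u ^= clmulSoft128(blocks[j], hx[3 - j]);
    }
    try std.testing.expectEqual(acc, reduce(u));
}

The 4-, 8- and 16-block loops in blocks() below follow this same pattern with the batch size selected from the message length, and the trailing "remaining blocks" branch reuses the precomputed powers so the tail likewise needs only one reduction.
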
@@ -18,7 +18,11 @@ pub const Ghash = struct {
     pub const mac_length = 16;
     pub const key_length = 16;
 
-    const pc_count = if (builtin.mode != .ReleaseSmall) 8 else 1;
+    const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
+    const agg_2_treshold = 5 * block_length;
+    const agg_4_treshold = 22 * block_length;
+    const agg_8_treshold = 84 * block_length;
+    const agg_16_treshold = 328 * block_length;
 
     hx: [pc_count]Precomp,
     acc: u128 = 0,
@@ -38,19 +42,26 @@ pub const Ghash = struct {
 
         var hx: [pc_count]Precomp = undefined;
         hx[0] = h;
+        if (block_count >= agg_2_treshold) {
+            hx[1] = gcmReduce(clsq128(hx[0])); // h^2
+        }
         if (builtin.mode != .ReleaseSmall) {
-            if (block_count > 2) {
-                hx[1] = gcm_reduce(clsq128(hx[0])); // h^2
+            if (block_count >= agg_4_treshold) {
+                hx[2] = gcmReduce(clmul128(hx[1], h)); // h^3
+                hx[3] = gcmReduce(clsq128(hx[1])); // h^4 = h^2^2
             }
-            if (block_count > 4) {
-                hx[2] = gcm_reduce(clmul128(hx[1], h)); // h^3
-                hx[3] = gcm_reduce(clsq128(hx[1])); // h^4
+            if (block_count >= agg_8_treshold) {
+                hx[4] = gcmReduce(clmul128(hx[3], h)); // h^5
+                hx[5] = gcmReduce(clsq128(hx[2])); // h^6 = h^3^2
+                hx[6] = gcmReduce(clmul128(hx[5], h)); // h^7
+                hx[7] = gcmReduce(clsq128(hx[3])); // h^8 = h^4^2
             }
-            if (block_count > 8) {
-                hx[4] = gcm_reduce(clmul128(hx[3], h)); // h^5
-                hx[5] = gcm_reduce(clmul128(hx[4], h)); // h^6
-                hx[6] = gcm_reduce(clmul128(hx[5], h)); // h^7
-                hx[7] = gcm_reduce(clsq128(hx[3])); // h^8
+            if (block_count >= agg_16_treshold) {
+                var i: usize = 8;
+                while (i < 16) : (i += 2) {
+                    hx[i] = gcmReduce(clmul128(hx[i - 1], h));
+                    hx[i + 1] = gcmReduce(clsq128(hx[i / 2]));
+                }
             }
         }
         return Ghash{ .hx = hx };
@@ -62,29 +73,52 @@ pub const Ghash = struct {
     }
 
     // Carryless multiplication of two 64-bit integers for x86_64.
-    inline fn clmul_pclmul(x: u64, y: u64) u128 {
-        const product = asm (
-            \\ vpclmulqdq $0x00, %[x], %[y], %[out]
-            : [out] "=x" (-> @Vector(2, u64)),
-            : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
-              [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
-        );
-        return (@as(u128, product[1]) << 64) | product[0];
+    inline fn clmulPclmul(x: u128, y: u128, comptime half: enum { lo, hi }) u128 {
+        if (half == .hi) {
+            const product = asm (
+                \\ vpclmulqdq $0x11, %[x], %[y], %[out]
+                : [out] "=x" (-> @Vector(2, u64)),
+                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        } else {
+            const product = asm (
+                \\ vpclmulqdq $0x00, %[x], %[y], %[out]
+                : [out] "=x" (-> @Vector(2, u64)),
+                : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        }
     }
 
     // Carryless multiplication of two 64-bit integers for ARM crypto.
-    inline fn clmul_pmull(x: u64, y: u64) u128 {
-        const product = asm (
-            \\ pmull %[out].1q, %[x].1d, %[y].1d
-            : [out] "=w" (-> @Vector(2, u64)),
-            : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
-              [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
-        );
-        return (@as(u128, product[1]) << 64) | product[0];
+    inline fn clmulPmull(x: u128, y: u128, comptime half: enum { lo, hi }) u128 {
+        if (half == .hi) {
+            const product = asm (
+                \\ pmull2 %[out].1q, %[x].2d, %[y].2d
+                : [out] "=w" (-> @Vector(2, u64)),
+                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        } else {
+            const product = asm (
+                \\ pmull %[out].1q, %[x].1d, %[y].1d
+                : [out] "=w" (-> @Vector(2, u64)),
+                : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))),
+                  [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))),
+            );
+            return @bitCast(u128, product);
+        }
     }
 
     // Software carryless multiplication of two 64-bit integers.
-    fn clmul_soft(x: u64, y: u64) u128 {
+    fn clmulSoft(x_: u128, y_: u128, comptime half: enum { lo, hi }) u128 {
+        const x = @truncate(u64, if (half == .hi) x_ >> 64 else x_);
+        const y = @truncate(u64, if (half == .hi) y_ >> 64 else y_);
+
         const x0 = x & 0x1111111111111110;
         const x1 = x & 0x2222222222222220;
         const x2 = x & 0x4444444444444440;
@@ -116,32 +150,31 @@ pub const Ghash = struct {
         const lo = @truncate(u64, x);
         const hi = @truncate(u64, x >> 64);
         const mid = lo ^ hi;
-        const r_lo = clmul(lo, lo);
-        const r_hi = clmul(hi, hi);
-        const r_mid = clmul(mid, mid) ^ r_lo ^ r_hi;
+        const r_lo = clmul(x, x, .lo);
+        const r_hi = clmul(x, x, .hi);
+        const r_mid = clmul(mid, mid, .lo) ^ r_lo ^ r_hi;
         return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
     }
 
     // Multiply two 128-bit integers in GF(2^128).
     inline fn clmul128(x: u128, y: u128) u256 {
-        const x_lo = @truncate(u64, x);
         const x_hi = @truncate(u64, x >> 64);
-        const y_lo = @truncate(u64, y);
         const y_hi = @truncate(u64, y >> 64);
-        const r_lo = clmul(x_lo, y_lo);
-        const r_hi = clmul(x_hi, y_hi);
-        const r_mid = clmul(x_lo ^ x_hi, y_lo ^ y_hi) ^ r_lo ^ r_hi;
+        const r_lo = clmul(x, y, .lo);
+        const r_hi = clmul(x, y, .hi);
+        const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
         return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo;
     }
 
     // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
     // This is done *without reversing the bits*, using Shay Gueron's black magic demysticated here:
     // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
-    inline fn gcm_reduce(x: u256) u128 {
+    inline fn gcmReduce(x: u256) u128 {
         const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
-        const a = clmul(@truncate(u64, x), p64);
-        const b = ((@truncate(u128, x) << 64) | (@truncate(u128, x) >> 64)) ^ a;
-        const c = clmul(@truncate(u64, b), p64);
+        const lo = @truncate(u128, x);
+        const a = clmul(lo, p64, .lo);
+        const b = ((lo << 64) | (lo >> 64)) ^ a;
+        const c = clmul(b, p64, .lo);
         const d = ((b << 64) | (b >> 64)) ^ c;
         return d ^ @truncate(u128, x >> 128);
     }
@@ -150,103 +183,73 @@ pub const Ghash = struct {
     const has_avx = std.Target.x86.featureSetHas(builtin.cpu.features, .avx);
     const has_armaes = std.Target.aarch64.featureSetHas(builtin.cpu.features, .aes);
     const clmul = if (builtin.cpu.arch == .x86_64 and has_pclmul and has_avx) impl: {
-        break :impl clmul_pclmul;
+        break :impl clmulPclmul;
     } else if (builtin.cpu.arch == .aarch64 and has_armaes) impl: {
-        break :impl clmul_pmull;
+        break :impl clmulPmull;
     } else impl: {
-        break :impl clmul_soft;
+        break :impl clmulSoft;
     };
 
-    // Process a block of 16 bytes.
+    // Process 16 byte blocks.
     fn blocks(st: *Ghash, msg: []const u8) void {
         assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
         var acc = st.acc;
 
         var i: usize = 0;
 
-        if (builtin.mode != .ReleaseSmall) {
+        if (builtin.mode != .ReleaseSmall and msg.len >= agg_16_treshold) {
+            // 16-blocks aggregated reduction
+            while (i + 256 <= msg.len) : (i += 256) {
+                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[15 - 0]);
+                comptime var j = 1;
+                inline while (j < 16) : (j += 1) {
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[15 - j]);
+                }
+                acc = gcmReduce(u);
+            }
+        } else if (builtin.mode != .ReleaseSmall and msg.len >= agg_8_treshold) {
             // 8-blocks aggregated reduction
             while (i + 128 <= msg.len) : (i += 128) {
-                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
-                const z0 = acc ^ b0;
-                const z0h = clmul128(z0, st.hx[7]);
-
-                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
-                const b1h = clmul128(b1, st.hx[6]);
-
-                const b2 = mem.readIntBig(u128, msg[i..][32..48]);
-                const b2h = clmul128(b2, st.hx[5]);
-
-                const b3 = mem.readIntBig(u128, msg[i..][48..64]);
-                const b3h = clmul128(b3, st.hx[4]);
-
-                const b4 = mem.readIntBig(u128, msg[i..][64..80]);
-                const b4h = clmul128(b4, st.hx[3]);
-
-                const b5 = mem.readIntBig(u128, msg[i..][80..96]);
-                const b5h = clmul128(b5, st.hx[2]);
-
-                const b6 = mem.readIntBig(u128, msg[i..][96..112]);
-                const b6h = clmul128(b6, st.hx[1]);
-
-                const b7 = mem.readIntBig(u128, msg[i..][112..128]);
-                const b7h = clmul128(b7, st.hx[0]);
-
-                const u = z0h ^ b1h ^ b2h ^ b3h ^ b4h ^ b5h ^ b6h ^ b7h;
-                acc = gcm_reduce(u);
+                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[7 - 0]);
+                comptime var j = 1;
+                inline while (j < 8) : (j += 1) {
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[7 - j]);
+                }
+                acc = gcmReduce(u);
             }
-
+        } else if (builtin.mode != .ReleaseSmall and msg.len >= agg_4_treshold) {
             // 4-blocks aggregated reduction
             while (i + 64 <= msg.len) : (i += 64) {
-                // (acc + b0) * H^4 unreduced
-                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
-                const z0 = acc ^ b0;
-                const z0h = clmul128(z0, st.hx[3]);
-
-                // b1 * H^3 unreduced
-                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
-                const b1h = clmul128(b1, st.hx[2]);
-
-                // b2 * H^2 unreduced
-                const b2 = mem.readIntBig(u128, msg[i..][32..48]);
-                const b2h = clmul128(b2, st.hx[1]);
-
-                // b3 * H unreduced
-                const b3 = mem.readIntBig(u128, msg[i..][48..64]);
-                const b3h = clmul128(b3, st.hx[0]);
-
-                // (((acc + b0) * H^4) + B1 * H^3 + B2 * H^2 + B3 * H) (mod P)
-                const u = z0h ^ b1h ^ b2h ^ b3h;
-                acc = gcm_reduce(u);
+                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[3 - 0]);
+                comptime var j = 1;
+                inline while (j < 4) : (j += 1) {
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[3 - j]);
+                }
+                acc = gcmReduce(u);
             }
-
+        } else if (msg.len >= agg_2_treshold) {
             // 2-blocks aggregated reduction
             while (i + 32 <= msg.len) : (i += 32) {
-                // (acc + b0) * H^2 unreduced
-                const b0 = mem.readIntBig(u128, msg[i..][0..16]);
-                const z0 = acc ^ b0;
-                const z0h = clmul128(z0, st.hx[1]);
-
-                // b1 * H unreduced
-                const b1 = mem.readIntBig(u128, msg[i..][16..32]);
-                const b1h = clmul128(b1, st.hx[0]);
-
-                // (((acc + b0) * H^2) + B1 * H) (mod P)
-                const u = z0h ^ b1h;
-                acc = gcm_reduce(u);
+                var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[1 - 0]);
+                comptime var j = 1;
+                inline while (j < 2) : (j += 1) {
+                    u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[1 - j]);
+                }
+                acc = gcmReduce(u);
             }
         }
-
-        // single block
-        while (i + 16 <= msg.len) : (i += 16) {
-            // (acc + b0) * H unreduced
-            const b0 = mem.readIntBig(u128, msg[i..][0..16]);
-            const z0 = acc ^ b0;
-            const z0h = clmul128(z0, st.hx[0]);
-
-            // (acc + b0) * H (mod P)
-            acc = gcm_reduce(z0h);
+        // remaining blocks
+        if (i < msg.len) {
+            const n = (msg.len - i) / 16;
+            var u = clmul128(acc ^ mem.readIntBig(u128, msg[i..][0..16]), st.hx[n - 1 - 0]);
+            var j: usize = 1;
+            while (j < n) : (j += 1) {
+                u ^= clmul128(mem.readIntBig(u128, msg[i..][j * 16 ..][0..16]), st.hx[n - 1 - j]);
+            }
+            i += n * 16;
+            acc = gcmReduce(u);
         }
+        assert(i == msg.len);
         st.acc = acc;
     }
 
@@ -328,3 +331,36 @@ test "ghash" {
     st.final(&out);
     try htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
 }
+
+test "ghash2" {
+    var key: [16]u8 = undefined;
+    var i: usize = 0;
+    while (i < key.len) : (i += 1) {
+        key[i] = @intCast(u8, i * 15 + 1);
+    }
+    const tvs = [_]struct { len: usize, hash: [:0]const u8 }{
+        .{ .len = 5263, .hash = "b9395f37c131cd403a327ccf82ec016a" },
+        .{ .len = 1361, .hash = "8c24cb3664e9a36e32ddef0c8178ab33" },
+        .{ .len = 1344, .hash = "015d7243b52d62eee8be33a66a9658cc" },
+        .{ .len = 1000, .hash = "56e148799944193f351f2014ef9dec9d" },
+        .{ .len = 512, .hash = "ca4882ce40d37546185c57709d17d1ca" },
+        .{ .len = 128, .hash = "d36dc3aac16cfe21a75cd5562d598c1c" },
+        .{ .len = 111, .hash = "6e2bea99700fd19cf1694e7b56543320" },
+        .{ .len = 80, .hash = "aa28f4092a7cca155f3de279cf21aa17" },
+        .{ .len = 16, .hash = "9d7eb5ed121a52a4b0996e4ec9b98911" },
+        .{ .len = 1, .hash = "968a203e5c7a98b6d4f3112f4d6b89a7" },
+        .{ .len = 0, .hash = "00000000000000000000000000000000" },
+    };
+    inline for (tvs) |tv| {
+        var m: [tv.len]u8 = undefined;
+        i = 0;
+        while (i < m.len) : (i += 1) {
+            m[i] = @truncate(u8, i % 254 + 1);
+        }
+        var st = Ghash.init(&key);
+        st.update(&m);
+        var out: [16]u8 = undefined;
+        st.final(&out);
+        try htest.assertEqual(tv.hash, &out);
+    }
+}