const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const math = std.math;
const mem = std.mem;

const Precomp = u128;

/// GHASH is a universal hash function that uses multiplication by a fixed
/// parameter within a Galois field.
///
/// It is not a general-purpose hash function: the key must be secret, unpredictable, and never reused.
///
/// GHASH is typically used to compute the authentication tag in the AES-GCM construction.
pub const Ghash = Hash(.big, true);

/// POLYVAL is a universal hash function that uses multiplication by a fixed
/// parameter within a Galois field.
///
/// It is not a general-purpose hash function: the key must be secret, unpredictable, and never reused.
///
/// POLYVAL is typically used to compute the authentication tag in the AES-GCM-SIV construction.
pub const Polyval = Hash(.little, false);
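// Both are instances of the same construction below: GHASH encodes blocks in
// big-endian order and pre-shifts the key, while POLYVAL uses little-endian
// encoding with the raw key.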

fn Hash(comptime endian: std.builtin.Endian, comptime shift_key: bool) type {
    return struct {
        const Self = @This();

        pub const block_length: usize = 16;
        pub const mac_length = 16;
        pub const key_length = 16;

        const pc_count = if (builtin.mode != .ReleaseSmall) 16 else 2;
        const agg_4_threshold = 22;
        const agg_8_threshold = 84;
        const agg_16_threshold = 328;

        // Before the Haswell architecture, the carryless multiplication instruction was
        // extremely slow. Even with 128-bit operands, using Karatsuba multiplication was
        // thus faster than a schoolbook multiplication.
        // This is no longer the case: modern CPUs, including ARM-based ones, have a fast
        // carryless multiplication instruction, so using 4 multiplications is now faster
        // than 3 multiplications with extra shifts and additions.
        const mul_algorithm = if (builtin.cpu.arch == .x86) .karatsuba else .schoolbook;
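        // For reference, writing x = x1·2^64 ^ x0 and y = y1·2^64 ^ y0: schoolbook
        // computes the four half products x1·y1, x0·y0, x1·y0 and x0·y1, while
        // Karatsuba (see clmul128 below) derives the two cross terms from a single
        // product, (x0 ^ x1)·(y0 ^ y1) ^ x0·y0 ^ x1·y1, at the cost of extra XORs.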

        hx: [pc_count]Precomp,
        acc: u128 = 0,

        leftover: usize = 0,
        buf: [block_length]u8 align(16) = undefined,

        /// Initialize the GHASH state with a key, as well as the minimum number of
        /// blocks that will be processed, which is used to decide how many powers
        /// of the key to precompute.
        pub fn initForBlockCount(key: *const [key_length]u8, block_count: usize) Self {
            var h = mem.readInt(u128, key[0..16], endian);
            if (shift_key) {
                // Shift the key left by 1 bit and reduce, converting the GCM key to
                // the representation expected by the shared multiplication code.
                const carry = ((@as(u128, 0xc2) << 120) | 1) & (@as(u128, 0) -% (h >> 127));
                h = (h << 1) ^ carry;
            }
            var hx: [pc_count]Precomp = undefined;
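            // hx[i] holds h^(i+1); these precomputed powers of the key drive the
            // aggregated reduction in blocks() below.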
            hx[0] = h;
            hx[1] = reduce(clsq128(hx[0])); // h^2

            if (builtin.mode != .ReleaseSmall) {
                hx[2] = reduce(clmul128(hx[1], h)); // h^3
                hx[3] = reduce(clsq128(hx[1])); // h^4 = h^2^2
                if (block_count >= agg_8_threshold) {
                    hx[4] = reduce(clmul128(hx[3], h)); // h^5
                    hx[5] = reduce(clsq128(hx[2])); // h^6 = h^3^2
                    hx[6] = reduce(clmul128(hx[5], h)); // h^7
                    hx[7] = reduce(clsq128(hx[3])); // h^8 = h^4^2
                }
                if (block_count >= agg_16_threshold) {
                    var i: usize = 8;
                    while (i < 16) : (i += 2) {
                        hx[i] = reduce(clmul128(hx[i - 1], h));
                        hx[i + 1] = reduce(clsq128(hx[i / 2]));
                    }
                }
            }
            return Self{ .hx = hx };
        }

        /// Initialize the GHASH state with a key.
        pub fn init(key: *const [key_length]u8) Self {
            return Self.initForBlockCount(key, math.maxInt(usize));
        }

        const Selector = enum { lo, hi, hi_lo };

        // Carryless multiplication of two 64-bit integers for x86_64.
        fn clmulPclmul(x: u128, y: u128, comptime half: Selector) u128 {
            switch (half) {
                .hi => {
                    const product = asm (
                        \\ vpclmulqdq $0x11, %[x], %[y], %[out]
                        : [out] "=x" (-> @Vector(2, u64)),
                        : [x] "x" (@as(@Vector(2, u64), @bitCast(x))),
                          [y] "x" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
                .lo => {
                    const product = asm (
                        \\ vpclmulqdq $0x00, %[x], %[y], %[out]
                        : [out] "=x" (-> @Vector(2, u64)),
                        : [x] "x" (@as(@Vector(2, u64), @bitCast(x))),
                          [y] "x" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
                .hi_lo => {
                    const product = asm (
                        \\ vpclmulqdq $0x10, %[x], %[y], %[out]
                        : [out] "=x" (-> @Vector(2, u64)),
                        : [x] "x" (@as(@Vector(2, u64), @bitCast(x))),
                          [y] "x" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
            }
        }

        // Carryless multiplication of two 64-bit integers for ARM crypto.
        fn clmulPmull(x: u128, y: u128, comptime half: Selector) u128 {
            switch (half) {
                .hi => {
                    const product = asm (
                        \\ pmull2 %[out].1q, %[x].2d, %[y].2d
                        : [out] "=w" (-> @Vector(2, u64)),
                        : [x] "w" (@as(@Vector(2, u64), @bitCast(x))),
                          [y] "w" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
                .lo => {
                    const product = asm (
                        \\ pmull %[out].1q, %[x].1d, %[y].1d
                        : [out] "=w" (-> @Vector(2, u64)),
                        : [x] "w" (@as(@Vector(2, u64), @bitCast(x))),
                          [y] "w" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
                .hi_lo => {
                    const product = asm (
                        \\ pmull %[out].1q, %[x].1d, %[y].1d
                        : [out] "=w" (-> @Vector(2, u64)),
                        : [x] "w" (@as(@Vector(2, u64), @bitCast(x >> 64))),
                          [y] "w" (@as(@Vector(2, u64), @bitCast(y))),
                    );
                    return @as(u128, @bitCast(product));
                },
            }
        }

        /// clmulSoft128 requires native 128-bit registers; on platforms without them,
        /// clmulSoft128_64 is faster.
        const clmulSoft = switch (builtin.cpu.arch) {
            .wasm32, .wasm64 => clmulSoft128_64,
            else => if (std.simd.suggestVectorLength(u128) != null) clmulSoft128 else clmulSoft128_64,
        };

        // Software carryless multiplication of two 64-bit integers using native 128-bit registers.
        fn clmulSoft128(x_: u128, y_: u128, comptime half: Selector) u128 {
            const x = @as(u64, @truncate(if (half == .hi or half == .hi_lo) x_ >> 64 else x_));
            const y = @as(u64, @truncate(if (half == .hi) y_ >> 64 else y_));

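            // Constant-time multiplication trick, as used in BearSSL's bmul64:
            // split each operand into four groups holding every 4th bit, so that an
            // ordinary integer multiply cannot propagate carries between groups.
            // The low 4 bits of x are cleared here and handled separately (`extra`
            // below), keeping every column sum small enough to avoid overflowing
            // into the neighboring bit group.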
            const x0 = x & 0x1111111111111110;
            const x1 = x & 0x2222222222222220;
            const x2 = x & 0x4444444444444440;
            const x3 = x & 0x8888888888888880;
            const y0 = y & 0x1111111111111111;
            const y1 = y & 0x2222222222222222;
            const y2 = y & 0x4444444444444444;
            const y3 = y & 0x8888888888888888;
            const z0 = (x0 * @as(u128, y0)) ^ (x1 * @as(u128, y3)) ^ (x2 * @as(u128, y2)) ^ (x3 * @as(u128, y1));
            const z1 = (x0 * @as(u128, y1)) ^ (x1 * @as(u128, y0)) ^ (x2 * @as(u128, y3)) ^ (x3 * @as(u128, y2));
            const z2 = (x0 * @as(u128, y2)) ^ (x1 * @as(u128, y1)) ^ (x2 * @as(u128, y0)) ^ (x3 * @as(u128, y3));
            const z3 = (x0 * @as(u128, y3)) ^ (x1 * @as(u128, y2)) ^ (x2 * @as(u128, y1)) ^ (x3 * @as(u128, y0));

            const x0_mask = @as(u64, 0) -% (x & 1);
            const x1_mask = @as(u64, 0) -% ((x >> 1) & 1);
            const x2_mask = @as(u64, 0) -% ((x >> 2) & 1);
            const x3_mask = @as(u64, 0) -% ((x >> 3) & 1);
            const extra = (x0_mask & y) ^ (@as(u128, x1_mask & y) << 1) ^
                (@as(u128, x2_mask & y) << 2) ^ (@as(u128, x3_mask & y) << 3);

            return (z0 & 0x11111111111111111111111111111111) ^
                (z1 & 0x22222222222222222222222222222222) ^
                (z2 & 0x44444444444444444444444444444444) ^
                (z3 & 0x88888888888888888888888888888888) ^ extra;
        }

        // Software carryless multiplication of two 32-bit integers.
        fn clmulSoft32(x: u32, y: u32) u64 {
            const mulWide = math.mulWide;
            const a0 = x & 0x11111111;
            const a1 = x & 0x22222222;
            const a2 = x & 0x44444444;
            const a3 = x & 0x88888888;
            const b0 = y & 0x11111111;
            const b1 = y & 0x22222222;
            const b2 = y & 0x44444444;
            const b3 = y & 0x88888888;
            const c0 = mulWide(u32, a0, b0) ^ mulWide(u32, a1, b3) ^ mulWide(u32, a2, b2) ^ mulWide(u32, a3, b1);
            const c1 = mulWide(u32, a0, b1) ^ mulWide(u32, a1, b0) ^ mulWide(u32, a2, b3) ^ mulWide(u32, a3, b2);
            const c2 = mulWide(u32, a0, b2) ^ mulWide(u32, a1, b1) ^ mulWide(u32, a2, b0) ^ mulWide(u32, a3, b3);
            const c3 = mulWide(u32, a0, b3) ^ mulWide(u32, a1, b2) ^ mulWide(u32, a2, b1) ^ mulWide(u32, a3, b0);
            return (c0 & 0x1111111111111111) | (c1 & 0x2222222222222222) | (c2 & 0x4444444444444444) | (c3 & 0x8888888888888888);
        }

        // Software carryless multiplication of two 64-bit integers using 64-bit registers.
        fn clmulSoft128_64(x_: u128, y_: u128, comptime half: Selector) u128 {
            const a = @as(u64, @truncate(if (half == .hi or half == .hi_lo) x_ >> 64 else x_));
            const b = @as(u64, @truncate(if (half == .hi) y_ >> 64 else y_));
            const a0 = @as(u32, @truncate(a));
            const a1 = @as(u32, @truncate(a >> 32));
            const b0 = @as(u32, @truncate(b));
            const b1 = @as(u32, @truncate(b >> 32));
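            // Karatsuba over 32-bit halves: the middle product is recovered from
            // (a0 ^ a1)·(b0 ^ b1), so only three 32×32 carryless multiplies are needed.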
            const lo = clmulSoft32(a0, b0);
            const hi = clmulSoft32(a1, b1);
            const mid = clmulSoft32(a0 ^ a1, b0 ^ b1) ^ lo ^ hi;
            const res_lo = lo ^ (mid << 32);
            const res_hi = hi ^ (mid >> 32);
            return @as(u128, res_lo) | (@as(u128, res_hi) << 64);
        }

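        // An unreduced 256-bit product in redundant form: `mid` holds the two 64×64
        // cross products, which reduce() folds in at bit offset 64. For squarings the
        // cross terms cancel in characteristic 2, so `mid` is zero.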
        const I256 = struct {
            hi: u128,
            lo: u128,
            mid: u128,
        };

        fn xor256(x: *I256, y: I256) void {
            x.* = I256{
                .hi = x.hi ^ y.hi,
                .lo = x.lo ^ y.lo,
                .mid = x.mid ^ y.mid,
            };
        }

        // Square a 128-bit integer in GF(2^128).
        fn clsq128(x: u128) I256 {
            return .{
                .hi = clmul(x, x, .hi),
                .lo = clmul(x, x, .lo),
                .mid = 0,
            };
        }

        // Multiply two 128-bit integers in GF(2^128).
        fn clmul128(x: u128, y: u128) I256 {
            if (mul_algorithm == .karatsuba) {
                const x_hi = @as(u64, @truncate(x >> 64));
                const y_hi = @as(u64, @truncate(y >> 64));
                const r_lo = clmul(x, y, .lo);
                const r_hi = clmul(x, y, .hi);
                const r_mid = clmul(x ^ x_hi, y ^ y_hi, .lo) ^ r_lo ^ r_hi;
                return .{
                    .hi = r_hi,
                    .lo = r_lo,
                    .mid = r_mid,
                };
            } else {
                return .{
                    .hi = clmul(x, y, .hi),
                    .lo = clmul(x, y, .lo),
                    .mid = clmul(x, y, .hi_lo) ^ clmul(y, x, .hi_lo),
                };
            }
        }

        // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1.
        // This is done using Shay Gueron's black magic, demystified here:
        // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html
        fn reduce(x: I256) u128 {
            const hi = x.hi ^ (x.mid >> 64);
            const lo = x.lo ^ (x.mid << 64);
            const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64);
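            // p64 == 0xc200000000000000, the upper 64 bits of the reduction
            // polynomial; each of the two folding steps below (multiply the low half
            // by p64, then swap halves) reduces 64 bits of `lo` against the polynomial.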
            const a = clmul(lo, p64, .lo);
            const b = ((lo << 64) | (lo >> 64)) ^ a;
            const c = clmul(b, p64, .lo);
            const d = ((b << 64) | (b >> 64)) ^ c;
            return d ^ hi;
        }

        const has_pclmul = builtin.cpu.has(.x86, .pclmul);
        const has_avx = builtin.cpu.has(.x86, .avx);
        const has_armaes = builtin.cpu.has(.aarch64, .aes);
        // The C backend doesn't currently support passing vectors to inline asm.
        const clmul = if (builtin.cpu.arch == .x86_64 and builtin.zig_backend != .stage2_c and has_pclmul and has_avx) impl: {
            break :impl clmulPclmul;
        } else if (builtin.cpu.arch == .aarch64 and builtin.zig_backend != .stage2_c and has_armaes) impl: {
            break :impl clmulPmull;
        } else impl: {
            break :impl clmulSoft;
        };

        // Process 16-byte blocks.
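        // GHASH evaluates a polynomial in h, which Horner's rule would compute as
        // acc = (acc ^ m[i]) * h, with one modular reduction per block. The
        // aggregated variants below instead combine j blocks using the precomputed
        // powers h^1..h^j and pay for a single reduction per batch.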
        fn blocks(st: *Self, msg: []const u8) void {
            assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks
            var acc = st.acc;

            var i: usize = 0;

            if (builtin.mode != .ReleaseSmall and msg.len >= agg_16_threshold * block_length) {
                // 16-block aggregated reduction
                while (i + 256 <= msg.len) : (i += 256) {
                    var u = clmul128(acc ^ mem.readInt(u128, msg[i..][0..16], endian), st.hx[15 - 0]);
                    comptime var j = 1;
                    inline while (j < 16) : (j += 1) {
                        xor256(&u, clmul128(mem.readInt(u128, msg[i..][j * 16 ..][0..16], endian), st.hx[15 - j]));
                    }
                    acc = reduce(u);
                }
            } else if (builtin.mode != .ReleaseSmall and msg.len >= agg_8_threshold * block_length) {
                // 8-block aggregated reduction
                while (i + 128 <= msg.len) : (i += 128) {
                    var u = clmul128(acc ^ mem.readInt(u128, msg[i..][0..16], endian), st.hx[7 - 0]);
                    comptime var j = 1;
                    inline while (j < 8) : (j += 1) {
                        xor256(&u, clmul128(mem.readInt(u128, msg[i..][j * 16 ..][0..16], endian), st.hx[7 - j]));
                    }
                    acc = reduce(u);
                }
            } else if (builtin.mode != .ReleaseSmall and msg.len >= agg_4_threshold * block_length) {
                // 4-block aggregated reduction
                while (i + 64 <= msg.len) : (i += 64) {
                    var u = clmul128(acc ^ mem.readInt(u128, msg[i..][0..16], endian), st.hx[3 - 0]);
                    comptime var j = 1;
                    inline while (j < 4) : (j += 1) {
                        xor256(&u, clmul128(mem.readInt(u128, msg[i..][j * 16 ..][0..16], endian), st.hx[3 - j]));
                    }
                    acc = reduce(u);
                }
            }
            // 2-block aggregated reduction
            while (i + 32 <= msg.len) : (i += 32) {
                var u = clmul128(acc ^ mem.readInt(u128, msg[i..][0..16], endian), st.hx[1 - 0]);
                comptime var j = 1;
                inline while (j < 2) : (j += 1) {
                    xor256(&u, clmul128(mem.readInt(u128, msg[i..][j * 16 ..][0..16], endian), st.hx[1 - j]));
                }
                acc = reduce(u);
            }
            // remaining block, if any
            if (i < msg.len) {
                const u = clmul128(acc ^ mem.readInt(u128, msg[i..][0..16], endian), st.hx[0]);
                acc = reduce(u);
                i += 16;
            }
            assert(i == msg.len);
            st.acc = acc;
        }

        /// Absorb a message into the GHASH state.
        pub fn update(st: *Self, m: []const u8) void {
            var mb = m;

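            // Finish any partially filled block left over from a previous call.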
            if (st.leftover > 0) {
                const want = @min(block_length - st.leftover, mb.len);
                const mc = mb[0..want];
                for (mc, 0..) |x, i| {
                    st.buf[st.leftover + i] = x;
                }
                mb = mb[want..];
                st.leftover += want;
                if (st.leftover < block_length) {
                    return;
                }
                st.blocks(&st.buf);
                st.leftover = 0;
            }
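            // Hash as many whole blocks as possible straight from the input.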
            if (mb.len >= block_length) {
                const want = mb.len & ~(block_length - 1);
                st.blocks(mb[0..want]);
                mb = mb[want..];
            }
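            // Buffer the remaining partial block for the next call.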
            if (mb.len > 0) {
                for (mb, 0..) |x, i| {
                    st.buf[st.leftover + i] = x;
                }
                st.leftover += mb.len;
            }
        }

        /// Zero-pad to align the next input to the first byte of a block.
        pub fn pad(st: *Self) void {
            if (st.leftover == 0) {
                return;
            }
            var i = st.leftover;
            while (i < block_length) : (i += 1) {
                st.buf[i] = 0;
            }
            st.blocks(&st.buf);
            st.leftover = 0;
        }

        /// Compute the GHASH of the entire input.
        pub fn final(st: *Self, out: *[mac_length]u8) void {
            st.pad();
            mem.writeInt(u128, out[0..16], st.acc, endian);

            std.crypto.secureZero(u8, @as([*]u8, @ptrCast(st))[0..@sizeOf(Self)]);
        }

        /// Compute the GHASH of a message.
        pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {
            var st = Self.init(key);
            st.update(msg);
            st.final(out);
        }
    };
}

const htest = @import("test.zig");

test "ghash" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    var st = Ghash.init(&key);
    st.update(&m);
    var out: [16]u8 = undefined;
    st.final(&out);
    try htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);

    st = Ghash.init(&key);
    st.update(m[0..100]);
    st.update(m[100..]);
    st.final(&out);
    try htest.assertEqual("889295fa746e8b174bf4ec80a65dea41", &out);
}

test "ghash2" {
    var key: [16]u8 = undefined;
    var i: usize = 0;
    while (i < key.len) : (i += 1) {
        key[i] = @as(u8, @intCast(i * 15 + 1));
    }
    const tvs = [_]struct { len: usize, hash: [:0]const u8 }{
        .{ .len = 5263, .hash = "b9395f37c131cd403a327ccf82ec016a" },
        .{ .len = 1361, .hash = "8c24cb3664e9a36e32ddef0c8178ab33" },
        .{ .len = 1344, .hash = "015d7243b52d62eee8be33a66a9658cc" },
        .{ .len = 1000, .hash = "56e148799944193f351f2014ef9dec9d" },
        .{ .len = 512, .hash = "ca4882ce40d37546185c57709d17d1ca" },
        .{ .len = 128, .hash = "d36dc3aac16cfe21a75cd5562d598c1c" },
        .{ .len = 111, .hash = "6e2bea99700fd19cf1694e7b56543320" },
        .{ .len = 80, .hash = "aa28f4092a7cca155f3de279cf21aa17" },
        .{ .len = 16, .hash = "9d7eb5ed121a52a4b0996e4ec9b98911" },
        .{ .len = 1, .hash = "968a203e5c7a98b6d4f3112f4d6b89a7" },
        .{ .len = 0, .hash = "00000000000000000000000000000000" },
    };
    inline for (tvs) |tv| {
        var m: [tv.len]u8 = undefined;
        i = 0;
        while (i < m.len) : (i += 1) {
            m[i] = @as(u8, @truncate(i % 254 + 1));
        }
        var st = Ghash.init(&key);
        st.update(&m);
        var out: [16]u8 = undefined;
        st.final(&out);
        try htest.assertEqual(tv.hash, &out);
    }
}

test "polyval" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    var st = Polyval.init(&key);
    st.update(&m);
    var out: [16]u8 = undefined;
    st.final(&out);
    try htest.assertEqual("0713c82b170eef25c8955ddf72c85ccb", &out);

    st = Polyval.init(&key);
    st.update(m[0..100]);
    st.update(m[100..]);
    st.final(&out);
    try htest.assertEqual("0713c82b170eef25c8955ddf72c85ccb", &out);
}
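
// Added as an illustrative consistency check rather than an official test vector:
// the one-shot `create` helper should match the incremental init/update/final API
// for both GHASH and POLYVAL.
test "one-shot create matches incremental API" {
    const key = [_]u8{0x42} ** 16;
    const m = [_]u8{0x69} ** 256;

    var one_shot: [16]u8 = undefined;
    Ghash.create(&one_shot, &m, &key);
    var st = Ghash.init(&key);
    st.update(&m);
    var incremental: [16]u8 = undefined;
    st.final(&incremental);
    try std.testing.expectEqualSlices(u8, &one_shot, &incremental);

    var pv_one_shot: [16]u8 = undefined;
    Polyval.create(&pv_one_shot, &m, &key);
    var pv = Polyval.init(&key);
    pv.update(&m);
    var pv_incremental: [16]u8 = undefined;
    pv.final(&pv_incremental);
    try std.testing.expectEqualSlices(u8, &pv_one_shot, &pv_incremental);
}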