Commit 72064eba23

Frank Denis <github@pureftpd.org>
2020-10-23 16:18:35
std/crypto: vectorize BLAKE3
Gives a ~40% speedup on x86_64. However, the generic code remains faster on aarch64. This is still processing only one block at a time for now. I'm pretty confident that processing more blocks per round will eventually give a substantial performance improvement on all platforms with vector units.
1 parent 1b4ab74
Changed files (1)
lib
std
crypto
lib/std/crypto/blake3.zig
@@ -11,6 +11,7 @@ const fmt = std.fmt;
 const math = std.math;
 const mem = std.mem;
 const testing = std.testing;
+const Vector = std.meta.Vector;
 
 const ChunkIterator = struct {
     slice: []u8,
@@ -61,87 +62,173 @@ const KEYED_HASH: u8 = 1 << 4;
 const DERIVE_KEY_CONTEXT: u8 = 1 << 5;
 const DERIVE_KEY_MATERIAL: u8 = 1 << 6;
 
-// The mixing function, G, which mixes either a column or a diagonal.
-fn g(state: *[16]u32, a: usize, b: usize, c: usize, d: usize, mx: u32, my: u32) void {
-    _ = @addWithOverflow(u32, state[a], state[b], &state[a]);
-    _ = @addWithOverflow(u32, state[a], mx, &state[a]);
-    state[d] = math.rotr(u32, state[d] ^ state[a], 16);
-    _ = @addWithOverflow(u32, state[c], state[d], &state[c]);
-    state[b] = math.rotr(u32, state[b] ^ state[c], 12);
-    _ = @addWithOverflow(u32, state[a], state[b], &state[a]);
-    _ = @addWithOverflow(u32, state[a], my, &state[a]);
-    state[d] = math.rotr(u32, state[d] ^ state[a], 8);
-    _ = @addWithOverflow(u32, state[c], state[d], &state[c]);
-    state[b] = math.rotr(u32, state[b] ^ state[c], 7);
-}
+const CompressVectorized = struct {
+    const Lane = Vector(4, u32);
+    const Rows = [4]Lane;
 
-fn round(state: *[16]u32, msg: [16]u32, schedule: [16]u8) void {
-    // Mix the columns.
-    g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
-    g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
-    g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
-    g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
-
-    // Mix the diagonals.
-    g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
-    g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
-    g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
-    g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
-}
+    inline fn rot(x: Lane, comptime n: u5) Lane {
+        return (x >> @splat(4, @as(u5, n))) | (x << @splat(4, @as(u5, 1 +% ~n)));
+    }
 
-fn compress(
-    chaining_value: [8]u32,
-    block_words: [16]u32,
-    block_len: u32,
-    counter: u64,
-    flags: u8,
-) [16]u32 {
-    var state = [16]u32{
-        chaining_value[0],
-        chaining_value[1],
-        chaining_value[2],
-        chaining_value[3],
-        chaining_value[4],
-        chaining_value[5],
-        chaining_value[6],
-        chaining_value[7],
-        IV[0],
-        IV[1],
-        IV[2],
-        IV[3],
-        @truncate(u32, counter),
-        @truncate(u32, counter >> 32),
-        block_len,
-        flags,
-    };
-    for (MSG_SCHEDULE) |schedule| {
-        round(&state, block_words, schedule);
+    inline fn g(comptime even: bool, rows: *Rows, m: Lane) void {
+        rows[0] +%= rows[1] +% m;
+        rows[3] ^= rows[0];
+        rows[3] = rot(rows[3], if (even) 8 else 16);
+        rows[2] +%= rows[3];
+        rows[1] ^= rows[2];
+        rows[1] = rot(rows[1], if (even) 7 else 12);
     }
-    for (chaining_value) |_, i| {
-        state[i] ^= state[i + 8];
-        state[i + 8] ^= chaining_value[i];
+
+    inline fn diagonalize(rows: *Rows) void {
+        rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 3, 0, 1, 2 });
+        rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 });
+        rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 1, 2, 3, 0 });
     }
-    return state;
-}
+
+    inline fn undiagonalize(rows: *Rows) void {
+        rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 1, 2, 3, 0 });
+        rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 });
+        rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 3, 0, 1, 2 });
+    }
+
+    fn compress(
+        chaining_value: [8]u32,
+        block_words: [16]u32,
+        block_len: u32,
+        counter: u64,
+        flags: u8,
+    ) [16]u32 {
+        const md = Lane{ @truncate(u32, counter), @truncate(u32, counter >> 32), block_len, @as(u32, flags) };
+        var rows = Rows{ chaining_value[0..4].*, chaining_value[4..8].*, IV[0..4].*, md };
+
+        var m = Rows{ block_words[0..4].*, block_words[4..8].*, block_words[8..12].*, block_words[12..16].* };
+        var t0 = @shuffle(u32, m[0], m[1], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) });
+        g(false, &rows, t0);
+        var t1 = @shuffle(u32, m[0], m[1], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) });
+        g(true, &rows, t1);
+        diagonalize(&rows);
+        var t2 = @shuffle(u32, m[2], m[3], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) });
+        t2 = @shuffle(u32, t2, undefined, [_]i32{ 3, 0, 1, 2 });
+        g(false, &rows, t2);
+        var t3 = @shuffle(u32, m[2], m[3], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) });
+        t3 = @shuffle(u32, t3, undefined, [_]i32{ 3, 0, 1, 2 });
+        g(true, &rows, t3);
+        undiagonalize(&rows);
+        m = Rows{ t0, t1, t2, t3 };
+
+        var i: usize = 0;
+        while (i < 6) : (i += 1) {
+            t0 = @shuffle(u32, m[0], m[1], [_]i32{ 2, 1, (-1 - 1), (-1 - 3) });
+            t0 = @shuffle(u32, t0, undefined, [_]i32{ 1, 2, 3, 0 });
+            g(false, &rows, t0);
+            t1 = @shuffle(u32, m[2], m[3], [_]i32{ 2, 2, (-1 - 3), (-1 - 3) });
+            var tt = @shuffle(u32, m[0], undefined, [_]i32{ 3, 3, 0, 0 });
+            t1 = @shuffle(u32, tt, t1, [_]i32{ 0, (-1 - 1), 2, (-1 - 3) });
+            g(true, &rows, t1);
+            diagonalize(&rows);
+            t2 = @shuffle(u32, m[3], m[1], [_]i32{ 0, 1, (-1 - 0), (-1 - 1) });
+            tt = @shuffle(u32, t2, m[2], [_]i32{ 0, 1, 2, (-1 - 3) });
+            t2 = @shuffle(u32, tt, undefined, [_]i32{ 0, 2, 3, 1 });
+            g(false, &rows, t2);
+            t3 = @shuffle(u32, m[1], m[3], [_]i32{ 2, (-1 - 2), 3, (-1 - 3) });
+            tt = @shuffle(u32, m[2], t3, [_]i32{ 0, (-1 - 0), 1, (-1 - 1) });
+            t3 = @shuffle(u32, tt, undefined, [_]i32{ 2, 3, 1, 0 });
+            g(true, &rows, t3);
+            undiagonalize(&rows);
+            m = Rows{ t0, t1, t2, t3 };
+        }
+
+        rows[0] ^= rows[2];
+        rows[1] ^= rows[3];
+        rows[2] ^= Vector(4, u32){ chaining_value[0], chaining_value[1], chaining_value[2], chaining_value[3] };
+        rows[3] ^= Vector(4, u32){ chaining_value[4], chaining_value[5], chaining_value[6], chaining_value[7] };
+
+        return @bitCast([16]u32, rows);
+    }
+};
+
+const CompressGeneric = struct {
+    fn g(state: *[16]u32, comptime a: usize, comptime b: usize, comptime c: usize, comptime d: usize, mx: u32, my: u32) void {
+        state[a] +%= state[b] +% mx;
+        state[d] = math.rotr(u32, state[d] ^ state[a], 16);
+        state[c] +%= state[d];
+        state[b] = math.rotr(u32, state[b] ^ state[c], 12);
+        state[a] +%= state[b] +% my;
+        state[d] = math.rotr(u32, state[d] ^ state[a], 8);
+        state[c] +%= state[d];
+        state[b] = math.rotr(u32, state[b] ^ state[c], 7);
+    }
+
+    fn round(state: *[16]u32, msg: [16]u32, schedule: [16]u8) void {
+        // Mix the columns.
+        g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
+        g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
+        g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
+        g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
+
+        // Mix the diagonals.
+        g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
+        g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
+        g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
+        g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
+    }
+
+    fn compress(
+        chaining_value: [8]u32,
+        block_words: [16]u32,
+        block_len: u32,
+        counter: u64,
+        flags: u8,
+    ) [16]u32 {
+        var state = [16]u32{
+            chaining_value[0],
+            chaining_value[1],
+            chaining_value[2],
+            chaining_value[3],
+            chaining_value[4],
+            chaining_value[5],
+            chaining_value[6],
+            chaining_value[7],
+            IV[0],
+            IV[1],
+            IV[2],
+            IV[3],
+            @truncate(u32, counter),
+            @truncate(u32, counter >> 32),
+            block_len,
+            flags,
+        };
+        for (MSG_SCHEDULE) |schedule| {
+            round(&state, block_words, schedule);
+        }
+        for (chaining_value) |_, i| {
+            state[i] ^= state[i + 8];
+            state[i + 8] ^= chaining_value[i];
+        }
+        return state;
+    }
+};
+
+const compress = if (std.Target.current.cpu.arch == .x86_64) CompressVectorized.compress else CompressGeneric.compress;
 
 fn first8Words(words: [16]u32) [8]u32 {
     return @ptrCast(*const [8]u32, &words).*;
 }
 
-fn wordsFromLittleEndianBytes(words: []u32, bytes: []const u8) void {
-    var byte_slice = bytes;
-    for (words) |*word| {
-        word.* = mem.readIntSliceLittle(u32, byte_slice);
-        byte_slice = byte_slice[4..];
+fn wordsFromLittleEndianBytes(comptime count: usize, bytes: [count * 4]u8) [count]u32 {
+    var words: [count]u32 = undefined;
+    for (words) |*word, i| {
+        word.* = mem.readIntSliceLittle(u32, bytes[4 * i ..]);
     }
+    return words;
 }
 
 // Each chunk or parent node can produce either an 8-word chaining value or, by
 // setting the ROOT flag, any number of final output bytes. The Output struct
 // captures the state just prior to choosing between those two possibilities.
 const Output = struct {
-    input_chaining_value: [8]u32,
-    block_words: [16]u32,
+    input_chaining_value: [8]u32 align(16),
+    block_words: [16]u32 align(16),
     block_len: u32,
     counter: u64,
     flags: u8,
@@ -181,9 +268,9 @@ const Output = struct {
 };
 
 const ChunkState = struct {
-    chaining_value: [8]u32,
+    chaining_value: [8]u32 align(16),
     chunk_counter: u64,
-    block: [BLOCK_LEN]u8 = [_]u8{0} ** BLOCK_LEN,
+    block: [BLOCK_LEN]u8 align(16) = [_]u8{0} ** BLOCK_LEN,
     block_len: u8 = 0,
     blocks_compressed: u8 = 0,
     flags: u8,
@@ -218,8 +305,7 @@ const ChunkState = struct {
             // If the block buffer is full, compress it and clear it. More
             // input is coming, so this compression is not CHUNK_END.
             if (self.block_len == BLOCK_LEN) {
-                var block_words: [16]u32 = undefined;
-                wordsFromLittleEndianBytes(block_words[0..], self.block[0..]);
+                const block_words = wordsFromLittleEndianBytes(16, self.block);
                 self.chaining_value = first8Words(compress(
                     self.chaining_value,
                     block_words,
@@ -238,8 +324,7 @@ const ChunkState = struct {
     }
 
     fn output(self: *const ChunkState) Output {
-        var block_words: [16]u32 = undefined;
-        wordsFromLittleEndianBytes(block_words[0..], self.block[0..]);
+        const block_words = wordsFromLittleEndianBytes(16, self.block);
         return Output{
             .input_chaining_value = self.chaining_value,
             .block_words = block_words,
@@ -256,7 +341,7 @@ fn parentOutput(
     key: [8]u32,
     flags: u8,
 ) Output {
-    var block_words: [16]u32 = undefined;
+    var block_words: [16]u32 align(16) = undefined;
     mem.copy(u32, block_words[0..8], left_child_cv[0..]);
     mem.copy(u32, block_words[8..], right_child_cv[0..]);
     return Output{
@@ -303,8 +388,7 @@ pub const Blake3 = struct {
     /// Construct a new `Blake3` for the hash function, with an optional key
     pub fn init(options: Options) Blake3 {
         if (options.key) |key| {
-            var key_words: [8]u32 = undefined;
-            wordsFromLittleEndianBytes(key_words[0..], key[0..]);
+            const key_words = wordsFromLittleEndianBytes(8, key);
             return Blake3.init_internal(key_words, KEYED_HASH);
         } else {
             return Blake3.init_internal(IV, 0);
@@ -318,8 +402,7 @@ pub const Blake3 = struct {
         context_hasher.update(context);
         var context_key: [KEY_LEN]u8 = undefined;
         context_hasher.final(context_key[0..]);
-        var context_key_words: [8]u32 = undefined;
-        wordsFromLittleEndianBytes(context_key_words[0..], context_key[0..]);
+        const context_key_words = wordsFromLittleEndianBytes(8, context_key);
         return Blake3.init_internal(context_key_words, DERIVE_KEY_MATERIAL);
     }