Commit 6669885aa2

Frank Denis <124872+jedisct1@users.noreply.github.com>
2025-10-15 14:03:56
Faster BLAKE3 implementation (#25574)
This is a rewrite of the BLAKE3 implementation with vectorization. On Apple Silicon, the new implementation is about twice as fast as the previous one; with AVX2 it is more than 4x faster; and with AVX512 it is more than 7.5x faster than the previous implementation (from 678 MB/s to 5086 MB/s).
1 parent 70c21fd
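For context, here is a minimal usage sketch of the public Blake3 type declared in the diff below, covering one-shot hashing, incremental init/update/final, and the new finalizeSeek XOF entry point added by this commit. The std.crypto.hash.Blake3 alias is assumed to be the usual public path; everything else is taken from the API in the diff.

const std = @import("std");
// Assumed public alias for the Blake3 type defined in lib/std/crypto/blake3.zig.
const Blake3 = std.crypto.hash.Blake3;

test "blake3 one-shot, incremental, and seekable output" {
    const msg = "hello, BLAKE3";

    // One-shot hashing.
    var out1: [Blake3.digest_length]u8 = undefined;
    Blake3.hash(msg, &out1, .{});

    // Incremental hashing: any number of update() calls, then final().
    var h = Blake3.init(.{});
    h.update(msg[0..5]);
    h.update(msg[5..]);
    var out2: [Blake3.digest_length]u8 = undefined;
    h.final(&out2);
    try std.testing.expectEqualSlices(u8, &out1, &out2);

    // Extendable output: read 16 bytes starting at offset 32 of the output stream.
    var tail: [16]u8 = undefined;
    h.finalizeSeek(Blake3.digest_length, &tail);
}

A keyed hash would pass .{ .key = some_key } to init, and key derivation would go through initKdf, both shown in the diff below.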
Changed files (1)
lib/std/crypto/blake3.zig
@@ -1,391 +1,833 @@
-// Translated from BLAKE3 reference implementation.
-// Source: https://github.com/BLAKE3-team/BLAKE3
-
-const std = @import("../std.zig");
+const std = @import("std");
 const builtin = @import("builtin");
 const fmt = std.fmt;
-const math = std.math;
 const mem = std.mem;
-const testing = std.testing;
 
-const ChunkIterator = struct {
-    slice: []u8,
-    chunk_len: usize,
+const Vec4 = @Vector(4, u32);
+const Vec8 = @Vector(8, u32);
+const Vec16 = @Vector(16, u32);
 
-    fn init(slice: []u8, chunk_len: usize) ChunkIterator {
-        return ChunkIterator{
-            .slice = slice,
-            .chunk_len = chunk_len,
-        };
+const chunk_length = 1024;
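+// A tree of depth 54 suffices: 2^54 chunks of 1024 bytes covers the full 2^64-byte input space.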
+const max_depth = 54;
+
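+// Number of u32 lanes the target prefers per vector; falls back to 1 (scalar only) when unknown.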
+pub const simd_degree = std.simd.suggestVectorLength(u32) orelse 1;
+pub const max_simd_degree = simd_degree;
+const max_simd_degree_or_2 = if (max_simd_degree > 2) max_simd_degree else 2;
+
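+// Initialization constants shared with BLAKE2s and SHA-256 (first 32 bits of the fractional
+// parts of the square roots of the first eight primes).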
+const iv: [8]u32 = .{
+    0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
+    0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
+};
+
+const msg_schedule: [7][16]u8 = .{
+    .{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 },
+    .{ 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 },
+    .{ 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 },
+    .{ 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 },
+    .{ 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 },
+    .{ 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 },
+    .{ 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 },
+};
+
+const Flags = packed struct(u8) {
+    chunk_start: bool = false,
+    chunk_end: bool = false,
+    parent: bool = false,
+    root: bool = false,
+    keyed_hash: bool = false,
+    derive_key_context: bool = false,
+    derive_key_material: bool = false,
+    reserved: bool = false,
+
+    fn toInt(self: Flags) u8 {
+        return @bitCast(self);
     }
 
-    fn next(self: *ChunkIterator) ?[]u8 {
-        const next_chunk = self.slice[0..@min(self.chunk_len, self.slice.len)];
-        self.slice = self.slice[next_chunk.len..];
-        return if (next_chunk.len > 0) next_chunk else null;
+    fn with(self: Flags, other: Flags) Flags {
+        return @bitCast(self.toInt() | other.toInt());
     }
 };
 
-const OUT_LEN: usize = 32;
-const KEY_LEN: usize = 32;
-const BLOCK_LEN: usize = 64;
-const CHUNK_LEN: usize = 1024;
+const rotr = std.math.rotr;
 
-const IV = [8]u32{
-    0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
-};
+inline fn rotr32(w: u32, c: u5) u32 {
+    return rotr(u32, w, c);
+}
 
-const MSG_SCHEDULE = [7][16]u8{
-    [_]u8{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 },
-    [_]u8{ 2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8 },
-    [_]u8{ 3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1 },
-    [_]u8{ 10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6 },
-    [_]u8{ 12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4 },
-    [_]u8{ 9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7 },
-    [_]u8{ 11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13 },
-};
+inline fn load32(bytes: []const u8) u32 {
+    return mem.readInt(u32, bytes[0..4], .little);
+}
 
-// These are the internal flags that we use to domain separate root/non-root,
-// chunk/parent, and chunk beginning/middle/end. These get set at the high end
-// of the block flags word in the compression function, so their values start
-// high and go down.
-const CHUNK_START: u8 = 1 << 0;
-const CHUNK_END: u8 = 1 << 1;
-const PARENT: u8 = 1 << 2;
-const ROOT: u8 = 1 << 3;
-const KEYED_HASH: u8 = 1 << 4;
-const DERIVE_KEY_CONTEXT: u8 = 1 << 5;
-const DERIVE_KEY_MATERIAL: u8 = 1 << 6;
-
-const CompressVectorized = struct {
-    const Lane = @Vector(4, u32);
-    const Rows = [4]Lane;
-
-    fn g(comptime even: bool, rows: *Rows, m: Lane) void {
-        rows[0] +%= rows[1] +% m;
-        rows[3] ^= rows[0];
-        rows[3] = math.rotr(Lane, rows[3], if (even) 8 else 16);
-        rows[2] +%= rows[3];
-        rows[1] ^= rows[2];
-        rows[1] = math.rotr(Lane, rows[1], if (even) 7 else 12);
-    }
-
-    fn diagonalize(rows: *Rows) void {
-        rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 3, 0, 1, 2 });
-        rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 });
-        rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 1, 2, 3, 0 });
-    }
-
-    fn undiagonalize(rows: *Rows) void {
-        rows[0] = @shuffle(u32, rows[0], undefined, [_]i32{ 1, 2, 3, 0 });
-        rows[3] = @shuffle(u32, rows[3], undefined, [_]i32{ 2, 3, 0, 1 });
-        rows[2] = @shuffle(u32, rows[2], undefined, [_]i32{ 3, 0, 1, 2 });
-    }
-
-    fn compress(
-        chaining_value: [8]u32,
-        block_words: [16]u32,
-        block_len: u32,
-        counter: u64,
-        flags: u8,
-    ) [16]u32 {
-        const md = Lane{ @as(u32, @truncate(counter)), @as(u32, @truncate(counter >> 32)), block_len, @as(u32, flags) };
-        var rows = Rows{ chaining_value[0..4].*, chaining_value[4..8].*, IV[0..4].*, md };
-
-        var m = Rows{ block_words[0..4].*, block_words[4..8].*, block_words[8..12].*, block_words[12..16].* };
-        var t0 = @shuffle(u32, m[0], m[1], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) });
-        g(false, &rows, t0);
-        var t1 = @shuffle(u32, m[0], m[1], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) });
-        g(true, &rows, t1);
-        diagonalize(&rows);
-        var t2 = @shuffle(u32, m[2], m[3], [_]i32{ 0, 2, (-1 - 0), (-1 - 2) });
-        t2 = @shuffle(u32, t2, undefined, [_]i32{ 3, 0, 1, 2 });
-        g(false, &rows, t2);
-        var t3 = @shuffle(u32, m[2], m[3], [_]i32{ 1, 3, (-1 - 1), (-1 - 3) });
-        t3 = @shuffle(u32, t3, undefined, [_]i32{ 3, 0, 1, 2 });
-        g(true, &rows, t3);
-        undiagonalize(&rows);
-        m = Rows{ t0, t1, t2, t3 };
-
-        var i: usize = 0;
-        while (i < 6) : (i += 1) {
-            t0 = @shuffle(u32, m[0], m[1], [_]i32{ 2, 1, (-1 - 1), (-1 - 3) });
-            t0 = @shuffle(u32, t0, undefined, [_]i32{ 1, 2, 3, 0 });
-            g(false, &rows, t0);
-            t1 = @shuffle(u32, m[2], m[3], [_]i32{ 2, 2, (-1 - 3), (-1 - 3) });
-            var tt = @shuffle(u32, m[0], undefined, [_]i32{ 3, 3, 0, 0 });
-            t1 = @shuffle(u32, tt, t1, [_]i32{ 0, (-1 - 1), 2, (-1 - 3) });
-            g(true, &rows, t1);
-            diagonalize(&rows);
-            t2 = @shuffle(u32, m[3], m[1], [_]i32{ 0, 1, (-1 - 0), (-1 - 1) });
-            tt = @shuffle(u32, t2, m[2], [_]i32{ 0, 1, 2, (-1 - 3) });
-            t2 = @shuffle(u32, tt, undefined, [_]i32{ 0, 2, 3, 1 });
-            g(false, &rows, t2);
-            t3 = @shuffle(u32, m[1], m[3], [_]i32{ 2, (-1 - 2), 3, (-1 - 3) });
-            tt = @shuffle(u32, m[2], t3, [_]i32{ 0, (-1 - 0), 1, (-1 - 1) });
-            t3 = @shuffle(u32, tt, undefined, [_]i32{ 2, 3, 1, 0 });
-            g(true, &rows, t3);
-            undiagonalize(&rows);
-            m = Rows{ t0, t1, t2, t3 };
-        }
+inline fn store32(bytes: []u8, w: u32) void {
+    mem.writeInt(u32, bytes[0..4], w, .little);
+}
 
-        rows[0] ^= rows[2];
-        rows[1] ^= rows[3];
-        rows[2] ^= @Vector(4, u32){ chaining_value[0], chaining_value[1], chaining_value[2], chaining_value[3] };
-        rows[3] ^= @Vector(4, u32){ chaining_value[4], chaining_value[5], chaining_value[6], chaining_value[7] };
+fn loadKeyWords(key: [Blake3.key_length]u8) [8]u32 {
+    var key_words: [8]u32 = undefined;
+    for (0..8) |i| {
+        key_words[i] = load32(key[i * 4 ..][0..4]);
+    }
+    return key_words;
+}
 
-        return @as([16]u32, @bitCast(rows));
+fn storeCvWords(cv_words: [8]u32) [Blake3.digest_length]u8 {
+    var bytes: [Blake3.digest_length]u8 = undefined;
+    for (0..8) |i| {
+        store32(bytes[i * 4 ..][0..4], cv_words[i]);
     }
-};
+    return bytes;
+}
 
-const CompressGeneric = struct {
-    fn g(state: *[16]u32, comptime a: usize, comptime b: usize, comptime c: usize, comptime d: usize, mx: u32, my: u32) void {
-        state[a] +%= state[b] +% mx;
-        state[d] = math.rotr(u32, state[d] ^ state[a], 16);
-        state[c] +%= state[d];
-        state[b] = math.rotr(u32, state[b] ^ state[c], 12);
-        state[a] +%= state[b] +% my;
-        state[d] = math.rotr(u32, state[d] ^ state[a], 8);
-        state[c] +%= state[d];
-        state[b] = math.rotr(u32, state[b] ^ state[c], 7);
-    }
-
-    fn round(state: *[16]u32, msg: [16]u32, schedule: [16]u8) void {
-        // Mix the columns.
-        g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
-        g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
-        g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
-        g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
-
-        // Mix the diagonals.
-        g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
-        g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
-        g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
-        g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
-    }
-
-    fn compress(
-        chaining_value: [8]u32,
-        block_words: [16]u32,
-        block_len: u32,
-        counter: u64,
-        flags: u8,
-    ) [16]u32 {
-        var state = [16]u32{
-            chaining_value[0],
-            chaining_value[1],
-            chaining_value[2],
-            chaining_value[3],
-            chaining_value[4],
-            chaining_value[5],
-            chaining_value[6],
-            chaining_value[7],
-            IV[0],
-            IV[1],
-            IV[2],
-            IV[3],
-            @as(u32, @truncate(counter)),
-            @as(u32, @truncate(counter >> 32)),
-            block_len,
-            flags,
-        };
-        for (MSG_SCHEDULE) |schedule| {
-            round(&state, block_words, schedule);
+fn loadCvWords(bytes: [Blake3.digest_length]u8) [8]u32 {
+    var cv_words: [8]u32 = undefined;
+    for (0..8) |i| {
+        cv_words[i] = load32(bytes[i * 4 ..][0..4]);
+    }
+    return cv_words;
+}
+
+inline fn counterLow(counter: u64) u32 {
+    return @truncate(counter);
+}
+
+inline fn counterHigh(counter: u64) u32 {
+    return @truncate(counter >> 32);
+}
+
+fn highestOne(x: u64) u6 {
+    if (x == 0) return 0;
+    return @intCast(63 - @clz(x));
+}
+
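+// Round x down to a power of two (0 maps to 1); used to choose left-subtree sizes.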
+fn roundDownToPowerOf2(x: u64) u64 {
+    return @as(u64, 1) << highestOne(x | 1);
+}
+
+inline fn g(state: *[16]u32, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) void {
+    state[a] +%= state[b] +% x;
+    state[d] = rotr32(state[d] ^ state[a], 16);
+    state[c] +%= state[d];
+    state[b] = rotr32(state[b] ^ state[c], 12);
+    state[a] +%= state[b] +% y;
+    state[d] = rotr32(state[d] ^ state[a], 8);
+    state[c] +%= state[d];
+    state[b] = rotr32(state[b] ^ state[c], 7);
+}
+
+inline fn roundFn(state: *[16]u32, msg: *const [16]u32, round: usize) void {
+    const schedule = &msg_schedule[round];
+
+    g(state, 0, 4, 8, 12, msg[schedule[0]], msg[schedule[1]]);
+    g(state, 1, 5, 9, 13, msg[schedule[2]], msg[schedule[3]]);
+    g(state, 2, 6, 10, 14, msg[schedule[4]], msg[schedule[5]]);
+    g(state, 3, 7, 11, 15, msg[schedule[6]], msg[schedule[7]]);
+
+    g(state, 0, 5, 10, 15, msg[schedule[8]], msg[schedule[9]]);
+    g(state, 1, 6, 11, 12, msg[schedule[10]], msg[schedule[11]]);
+    g(state, 2, 7, 8, 13, msg[schedule[12]], msg[schedule[13]]);
+    g(state, 3, 4, 9, 14, msg[schedule[14]], msg[schedule[15]]);
+}
+
+fn compressPre(state: *[16]u32, cv: *const [8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags) void {
+    var block_words: [16]u32 = undefined;
+    for (0..16) |i| {
+        block_words[i] = load32(block[i * 4 ..][0..4]);
+    }
+
+    for (0..8) |i| {
+        state[i] = cv[i];
+    }
+    for (0..4) |i| {
+        state[i + 8] = iv[i];
+    }
+    state[12] = counterLow(counter);
+    state[13] = counterHigh(counter);
+    state[14] = @as(u32, block_len);
+    state[15] = @as(u32, flags.toInt());
+
+    for (0..7) |round| {
+        roundFn(state, &block_words, round);
+    }
+}
+
+fn compressInPlace(cv: *[8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags) void {
+    var state: [16]u32 = undefined;
+    compressPre(&state, cv, block, block_len, counter, flags);
+    for (0..8) |i| {
+        cv[i] = state[i] ^ state[i + 8];
+    }
+}
+
+fn compressXof(cv: *const [8]u32, block: []const u8, block_len: u8, counter: u64, flags: Flags, out: *[64]u8) void {
+    var state: [16]u32 = undefined;
+    compressPre(&state, cv, block, block_len, counter, flags);
+
+    for (0..8) |i| {
+        store32(out[i * 4 ..][0..4], state[i] ^ state[i + 8]);
+    }
+    for (0..8) |i| {
+        store32(out[(i + 8) * 4 ..][0..4], state[i + 8] ^ cv[i]);
+    }
+}
+
+fn hashOne(input: []const u8, blocks: usize, key: [8]u32, counter: u64, flags: Flags, flags_start: Flags, flags_end: Flags) [Blake3.digest_length]u8 {
+    var cv = key;
+    var block_flags = flags.with(flags_start);
+    var inp = input;
+    var remaining_blocks = blocks;
+
+    while (remaining_blocks > 0) {
+        if (remaining_blocks == 1) {
+            block_flags = block_flags.with(flags_end);
         }
-        for (chaining_value, 0..) |_, i| {
-            state[i] ^= state[i + 8];
-            state[i + 8] ^= chaining_value[i];
+        compressInPlace(&cv, inp[0..Blake3.block_length], Blake3.block_length, counter, block_flags);
+        inp = inp[Blake3.block_length..];
+        remaining_blocks -= 1;
+        block_flags = flags;
+    }
+
+    return storeCvWords(cv);
+}
+
+fn hashManyPortable(inputs: [][*]const u8, num_inputs: usize, blocks: usize, key: [8]u32, counter_arg: u64, increment_counter: bool, flags: Flags, flags_start: Flags, flags_end: Flags, out: []u8) void {
+    var counter = counter_arg;
+    for (0..num_inputs) |i| {
+        const input = inputs[i][0 .. blocks * Blake3.block_length];
+        const result = hashOne(input, blocks, key, counter, flags, flags_start, flags_end);
+        @memcpy(out[i * Blake3.digest_length ..][0..Blake3.digest_length], &result);
+        if (increment_counter) {
+            counter += 1;
         }
-        return state;
     }
-};
+}
 
-const compress = if (builtin.cpu.arch == .x86_64)
-    CompressVectorized.compress
-else
-    CompressGeneric.compress;
+fn transposeNxN(comptime Vec: type, comptime n: comptime_int, vecs: *[n]Vec) void {
+    const temp: [n]Vec = vecs.*;
 
-fn first8Words(words: [16]u32) [8]u32 {
-    return @as(*const [8]u32, @ptrCast(&words)).*;
+    inline for (0..n) |i| {
+        inline for (0..n) |j| {
+            vecs[i][j] = temp[j][i];
+        }
+    }
 }
 
-fn wordsFromLittleEndianBytes(comptime count: usize, bytes: [count * 4]u8) [count]u32 {
-    var words: [count]u32 = undefined;
-    for (&words, 0..) |*word, i| {
-        word.* = mem.readInt(u32, bytes[4 * i ..][0..4], .little);
+fn transposeMsg(comptime Vec: type, comptime n: comptime_int, inputs: [n][*]const u8, block_offset: usize, out: *[16]Vec) void {
+    const info = @typeInfo(Vec);
+    if (info != .vector) @compileError("transposeMsg requires a vector type");
+    if (info.vector.len != n) @compileError("vector width must match N");
+
+    var temp: [n][16]u32 = undefined;
+
+    for (0..n) |i| {
+        const block = inputs[i] + block_offset;
+        for (0..16) |j| {
+            temp[i][j] = load32(block[j * 4 ..][0..4]);
+        }
+    }
+
+    for (0..16) |j| {
+        var result: Vec = undefined;
+        inline for (0..n) |i| {
+            result[i] = temp[i][j];
+        }
+        out[j] = result;
     }
-    return words;
 }
 
-// Each chunk or parent node can produce either an 8-word chaining value or, by
-// setting the ROOT flag, any number of final output bytes. The Output struct
-// captures the state just prior to choosing between those two possibilities.
-const Output = struct {
-    input_chaining_value: [8]u32 align(16),
-    block_words: [16]u32 align(16),
-    block_len: u32,
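+// One BLAKE3 round applied to n hash states at once; each state vector holds one lane per input
+// (transposed layout).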
+fn roundFnVec(comptime Vec: type, v: *[16]Vec, m: *const [16]Vec, r: usize) void {
+    const schedule = &msg_schedule[r];
+
+    // Column round - first half
+    inline for (0..4) |i| {
+        v[i] +%= m[schedule[i * 2]];
+    }
+    inline for (0..4) |i| {
+        v[i] +%= v[i + 4];
+    }
+    inline for (0..4) |i| {
+        v[i + 12] ^= v[i];
+    }
+    inline for (0..4) |i| {
+        v[i + 12] = rotr(Vec, v[i + 12], 16);
+    }
+    inline for (0..4) |i| {
+        v[i + 8] +%= v[i + 12];
+    }
+    inline for (0..4) |i| {
+        v[i + 4] ^= v[i + 8];
+    }
+    inline for (0..4) |i| {
+        v[i + 4] = rotr(Vec, v[i + 4], 12);
+    }
+
+    // Column round - second half
+    inline for (0..4) |i| {
+        v[i] +%= m[schedule[i * 2 + 1]];
+    }
+    inline for (0..4) |i| {
+        v[i] +%= v[i + 4];
+    }
+    inline for (0..4) |i| {
+        v[i + 12] ^= v[i];
+    }
+    inline for (0..4) |i| {
+        v[i + 12] = rotr(Vec, v[i + 12], 8);
+    }
+    inline for (0..4) |i| {
+        v[i + 8] +%= v[i + 12];
+    }
+    inline for (0..4) |i| {
+        v[i + 4] ^= v[i + 8];
+    }
+    inline for (0..4) |i| {
+        v[i + 4] = rotr(Vec, v[i + 4], 7);
+    }
+
+    // Diagonal round - first half
+    inline for (0..4) |i| {
+        v[i] +%= m[schedule[i * 2 + 8]];
+    }
+    const b_indices = [4]u8{ 5, 6, 7, 4 };
+    inline for (0..4) |i| {
+        v[i] +%= v[b_indices[i]];
+    }
+    const d_indices = [4]u8{ 15, 12, 13, 14 };
+    inline for (0..4) |i| {
+        v[d_indices[i]] ^= v[i];
+    }
+    inline for (0..4) |i| {
+        v[d_indices[i]] = rotr(Vec, v[d_indices[i]], 16);
+    }
+    const c_indices = [4]u8{ 10, 11, 8, 9 };
+    inline for (0..4) |i| {
+        v[c_indices[i]] +%= v[d_indices[i]];
+    }
+    inline for (0..4) |i| {
+        v[b_indices[i]] ^= v[c_indices[i]];
+    }
+    inline for (0..4) |i| {
+        v[b_indices[i]] = rotr(Vec, v[b_indices[i]], 12);
+    }
+
+    // Diagonal round - second half
+    inline for (0..4) |i| {
+        v[i] +%= m[schedule[i * 2 + 9]];
+    }
+    inline for (0..4) |i| {
+        v[i] +%= v[b_indices[i]];
+    }
+    inline for (0..4) |i| {
+        v[d_indices[i]] ^= v[i];
+    }
+    inline for (0..4) |i| {
+        v[d_indices[i]] = rotr(Vec, v[d_indices[i]], 8);
+    }
+    inline for (0..4) |i| {
+        v[c_indices[i]] +%= v[d_indices[i]];
+    }
+    inline for (0..4) |i| {
+        v[b_indices[i]] ^= v[c_indices[i]];
+    }
+    inline for (0..4) |i| {
+        v[b_indices[i]] = rotr(Vec, v[b_indices[i]], 7);
+    }
+}
+
+fn hashVec(
+    comptime Vec: type,
+    comptime n: comptime_int,
+    inputs: [n][*]const u8,
+    blocks: usize,
+    key: [8]u32,
     counter: u64,
-    flags: u8,
+    increment_counter: bool,
+    flags: Flags,
+    flags_start: Flags,
+    flags_end: Flags,
+    out: *[n * Blake3.digest_length]u8,
+) void {
+    var h_vecs: [8]Vec = undefined;
+    for (0..8) |i| {
+        h_vecs[i] = @splat(key[i]);
+    }
 
-    fn chainingValue(self: *const Output) [8]u32 {
-        return first8Words(compress(
-            self.input_chaining_value,
-            self.block_words,
-            self.block_len,
-            self.counter,
-            self.flags,
-        ));
-    }
-
-    fn rootOutputBytes(self: *const Output, output: []u8) void {
-        var out_block_it = ChunkIterator.init(output, 2 * OUT_LEN);
-        var output_block_counter: usize = 0;
-        while (out_block_it.next()) |out_block| {
-            const words = compress(
-                self.input_chaining_value,
-                self.block_words,
-                self.block_len,
-                output_block_counter,
-                self.flags | ROOT,
-            );
-            var out_word_it = ChunkIterator.init(out_block, 4);
-            var word_counter: usize = 0;
-            while (out_word_it.next()) |out_word| {
-                var word_bytes: [4]u8 = undefined;
-                mem.writeInt(u32, &word_bytes, words[word_counter], .little);
-                @memcpy(out_word, word_bytes[0..out_word.len]);
-                word_counter += 1;
+    const counter_low_vec = if (increment_counter) blk: {
+        var result: Vec = undefined;
+        inline for (0..n) |i| {
+            result[i] = counterLow(counter + i);
+        }
+        break :blk result;
+    } else @as(Vec, @splat(counterLow(counter)));
+
+    const counter_high_vec = if (increment_counter) blk: {
+        var result: Vec = undefined;
+        inline for (0..n) |i| {
+            result[i] = counterHigh(counter + i);
+        }
+        break :blk result;
+    } else @as(Vec, @splat(counterHigh(counter)));
+
+    var block_flags = flags.with(flags_start);
+
+    for (0..blocks) |block| {
+        if (block + 1 == blocks) {
+            block_flags = block_flags.with(flags_end);
+        }
+
+        const block_len_vec: Vec = @splat(Blake3.block_length);
+        const block_flags_vec: Vec = @splat(@as(u32, block_flags.toInt()));
+
+        var msg_vecs: [16]Vec = undefined;
+        transposeMsg(Vec, n, inputs, block * Blake3.block_length, &msg_vecs);
+
+        var v: [16]Vec = .{
+            h_vecs[0],       h_vecs[1],        h_vecs[2],     h_vecs[3],
+            h_vecs[4],       h_vecs[5],        h_vecs[6],     h_vecs[7],
+            @splat(iv[0]),   @splat(iv[1]),    @splat(iv[2]), @splat(iv[3]),
+            counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
+        };
+
+        inline for (0..7) |r| {
+            roundFnVec(Vec, &v, &msg_vecs, r);
+        }
+
+        inline for (0..8) |i| {
+            h_vecs[i] = v[i] ^ v[i + 8];
+        }
+
+        block_flags = flags;
+    }
+
+    // Output serialization - different strategies for different widths
+    switch (n) {
+        4 => {
+            // Special interleaved pattern for Vec4
+            var out_vecs = [4]Vec{ h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3] };
+            transposeNxN(Vec, 4, &out_vecs);
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[0 * 16 + i * 4 ..][0..4], out_vecs[0][i], .little);
             }
-            output_block_counter += 1;
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[2 * 16 + i * 4 ..][0..4], out_vecs[1][i], .little);
+            }
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[4 * 16 + i * 4 ..][0..4], out_vecs[2][i], .little);
+            }
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[6 * 16 + i * 4 ..][0..4], out_vecs[3][i], .little);
+            }
+
+            out_vecs = [4]Vec{ h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7] };
+            transposeNxN(Vec, 4, &out_vecs);
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[1 * 16 + i * 4 ..][0..4], out_vecs[0][i], .little);
+            }
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[3 * 16 + i * 4 ..][0..4], out_vecs[1][i], .little);
+            }
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[5 * 16 + i * 4 ..][0..4], out_vecs[2][i], .little);
+            }
+            inline for (0..4) |i| {
+                mem.writeInt(u32, out[7 * 16 + i * 4 ..][0..4], out_vecs[3][i], .little);
+            }
+        },
+        8 => {
+            // Linear pattern with transpose for Vec8
+            var out_vecs = [8]Vec{ h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3], h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7] };
+            transposeNxN(Vec, 8, &out_vecs);
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[0 * 32 + i * 4 ..][0..4], out_vecs[0][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[1 * 32 + i * 4 ..][0..4], out_vecs[1][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[2 * 32 + i * 4 ..][0..4], out_vecs[2][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[3 * 32 + i * 4 ..][0..4], out_vecs[3][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[4 * 32 + i * 4 ..][0..4], out_vecs[4][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[5 * 32 + i * 4 ..][0..4], out_vecs[5][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[6 * 32 + i * 4 ..][0..4], out_vecs[6][i], .little);
+            }
+            inline for (0..8) |i| {
+                mem.writeInt(u32, out[7 * 32 + i * 4 ..][0..4], out_vecs[7][i], .little);
+            }
+        },
+        16 => {
+            // Direct lane-by-lane output for Vec16 (no transpose)
+            inline for (0..16) |lane| {
+                const hash_offset = lane * Blake3.digest_length;
+                inline for (0..8) |word_idx| {
+                    const word = h_vecs[word_idx][lane];
+                    out[hash_offset + word_idx * 4 + 0] = @truncate(word);
+                    out[hash_offset + word_idx * 4 + 1] = @truncate(word >> 8);
+                    out[hash_offset + word_idx * 4 + 2] = @truncate(word >> 16);
+                    out[hash_offset + word_idx * 4 + 3] = @truncate(word >> 24);
+                }
+            }
+        },
+        else => @compileError("Unsupported SIMD width"),
+    }
+}
+
+fn hashManySimd(
+    inputs: [][*]const u8,
+    num_inputs: usize,
+    blocks: usize,
+    key: [8]u32,
+    counter: u64,
+    increment_counter: bool,
+    flags: Flags,
+    flags_start: Flags,
+    flags_end: Flags,
+    out: []u8,
+) void {
+    var remaining = num_inputs;
+    var inp = inputs.ptr;
+    var out_ptr = out.ptr;
+    var cnt = counter;
+
+    const simd_deg = comptime simd_degree;
+
+    if (comptime simd_deg >= 16) {
+        while (remaining >= 16) {
+            const sixteen_inputs = [16][*]const u8{
+                inp[0],  inp[1],  inp[2],  inp[3],
+                inp[4],  inp[5],  inp[6],  inp[7],
+                inp[8],  inp[9],  inp[10], inp[11],
+                inp[12], inp[13], inp[14], inp[15],
+            };
+
+            var simd_out: [16 * Blake3.digest_length]u8 = undefined;
+            hashVec(Vec16, 16, sixteen_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out);
+
+            @memcpy(out_ptr[0 .. 16 * Blake3.digest_length], &simd_out);
+
+            if (increment_counter) cnt += 16;
+            inp += 16;
+            remaining -= 16;
+            out_ptr += 16 * Blake3.digest_length;
         }
     }
-};
+
+    if (comptime simd_deg >= 8) {
+        while (remaining >= 8) {
+            const eight_inputs = [8][*]const u8{
+                inp[0], inp[1], inp[2], inp[3],
+                inp[4], inp[5], inp[6], inp[7],
+            };
+
+            var simd_out: [8 * Blake3.digest_length]u8 = undefined;
+            hashVec(Vec8, 8, eight_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out);
+
+            @memcpy(out_ptr[0 .. 8 * Blake3.digest_length], &simd_out);
+
+            if (increment_counter) cnt += 8;
+            inp += 8;
+            remaining -= 8;
+            out_ptr += 8 * Blake3.digest_length;
+        }
+    }
+
+    if (comptime simd_deg >= 4) {
+        while (remaining >= 4) {
+            const four_inputs = [4][*]const u8{
+                inp[0],
+                inp[1],
+                inp[2],
+                inp[3],
+            };
+
+            var simd_out: [4 * Blake3.digest_length]u8 = undefined;
+            hashVec(Vec4, 4, four_inputs, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, &simd_out);
+
+            @memcpy(out_ptr[0 .. 4 * Blake3.digest_length], &simd_out);
+
+            if (increment_counter) cnt += 4;
+            inp += 4;
+            remaining -= 4;
+            out_ptr += 4 * Blake3.digest_length;
+        }
+    }
+
+    if (remaining > 0) {
+        hashManyPortable(inp[0..remaining], remaining, blocks, key, cnt, increment_counter, flags, flags_start, flags_end, out_ptr[0 .. remaining * Blake3.digest_length]);
+    }
+}
+
+fn hashMany(inputs: [][*]const u8, num_inputs: usize, blocks: usize, key: [8]u32, counter: u64, increment_counter: bool, flags: Flags, flags_start: Flags, flags_end: Flags, out: []u8) void {
+    if (comptime max_simd_degree >= 4) {
+        hashManySimd(inputs, num_inputs, blocks, key, counter, increment_counter, flags, flags_start, flags_end, out);
+    } else {
+        hashManyPortable(inputs, num_inputs, blocks, key, counter, increment_counter, flags, flags_start, flags_end, out);
+    }
+}
+
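+// Hash as many whole chunks as fit in a single SIMD batch; any trailing partial chunk is
+// finished with a scalar ChunkState.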
+fn compressChunksParallel(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: []u8) usize {
+    var chunks_array: [max_simd_degree][*]const u8 = undefined;
+    var input_position: usize = 0;
+    var chunks_array_len: usize = 0;
+
+    while (input.len - input_position >= chunk_length) {
+        chunks_array[chunks_array_len] = input[input_position..].ptr;
+        input_position += chunk_length;
+        chunks_array_len += 1;
+    }
+
+    hashMany(chunks_array[0..chunks_array_len], chunks_array_len, chunk_length / Blake3.block_length, key, chunk_counter, true, flags, .{ .chunk_start = true }, .{ .chunk_end = true }, out);
+
+    if (input.len > input_position) {
+        const counter = chunk_counter + @as(u64, chunks_array_len);
+        var chunk_state = ChunkState.init(key, flags);
+        chunk_state.chunk_counter = counter;
+        chunk_state.update(input[input_position..]);
+        const output = chunk_state.output();
+        const cv = output.chainingValue();
+        const cv_bytes = storeCvWords(cv);
+        @memcpy(out[chunks_array_len * Blake3.digest_length ..][0..Blake3.digest_length], &cv_bytes);
+        return chunks_array_len + 1;
+    } else {
+        return chunks_array_len;
+    }
+}
+
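+// Compress pairs of child chaining values into parent CVs in one batch; an odd leftover CV is
+// copied through unchanged.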
+fn compressParentsParallel(child_chaining_values: []const u8, num_chaining_values: usize, key: [8]u32, flags: Flags, out: []u8) usize {
+    var parents_array: [max_simd_degree_or_2][*]const u8 = undefined;
+    var parents_array_len: usize = 0;
+
+    while (num_chaining_values - (2 * parents_array_len) >= 2) {
+        parents_array[parents_array_len] = child_chaining_values[2 * parents_array_len * Blake3.digest_length ..].ptr;
+        parents_array_len += 1;
+    }
+
+    hashMany(parents_array[0..parents_array_len], parents_array_len, 1, key, 0, false, flags.with(.{ .parent = true }), .{}, .{}, out);
+
+    if (num_chaining_values > 2 * parents_array_len) {
+        @memcpy(out[parents_array_len * Blake3.digest_length ..][0..Blake3.digest_length], child_chaining_values[2 * parents_array_len * Blake3.digest_length ..][0..Blake3.digest_length]);
+        return parents_array_len + 1;
+    } else {
+        return parents_array_len;
+    }
+}
+
+fn compressSubtreeWide(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: []u8) usize {
+    if (input.len <= max_simd_degree * chunk_length) {
+        return compressChunksParallel(input, key, chunk_counter, flags, out);
+    }
+
+    const left_input_len = leftSubtreeLen(input.len);
+    const right_input = input[left_input_len..];
+    const right_chunk_counter = chunk_counter + @as(u64, left_input_len / chunk_length);
+
+    var cv_array: [2 * max_simd_degree_or_2 * Blake3.digest_length]u8 = undefined;
+    var degree: usize = max_simd_degree;
+    if (left_input_len > chunk_length and degree == 1) {
+        degree = 2;
+    }
+    const right_cvs = cv_array[degree * Blake3.digest_length ..];
+
+    const left_n = compressSubtreeWide(input[0..left_input_len], key, chunk_counter, flags, cv_array[0..]);
+    const right_n = compressSubtreeWide(right_input, key, right_chunk_counter, flags, right_cvs);
+
+    if (left_n == 1) {
+        @memcpy(out[0 .. 2 * Blake3.digest_length], cv_array[0 .. 2 * Blake3.digest_length]);
+        return 2;
+    }
+
+    const num_chaining_values = left_n + right_n;
+    return compressParentsParallel(&cv_array, num_chaining_values, key, flags, out);
+}
+
+fn compressSubtreeToParentNode(input: []const u8, key: [8]u32, chunk_counter: u64, flags: Flags, out: *[2 * Blake3.digest_length]u8) void {
+    var cv_array: [max_simd_degree_or_2 * Blake3.digest_length]u8 = undefined;
+    var num_cvs = compressSubtreeWide(input, key, chunk_counter, flags, &cv_array);
+
+    if (max_simd_degree_or_2 > 2) {
+        var out_array: [max_simd_degree_or_2 * Blake3.digest_length / 2]u8 = undefined;
+        while (num_cvs > 2) {
+            num_cvs = compressParentsParallel(&cv_array, num_cvs, key, flags, &out_array);
+            @memcpy(cv_array[0 .. num_cvs * Blake3.digest_length], out_array[0 .. num_cvs * Blake3.digest_length]);
+        }
+    }
+
+    @memcpy(out, cv_array[0 .. 2 * Blake3.digest_length]);
+}
+
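+// The left subtree gets the largest power-of-two number of full chunks that still leaves at
+// least one input byte for the right subtree.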
+fn leftSubtreeLen(input_len: usize) usize {
+    const full_chunks = (input_len - 1) / chunk_length;
+    return @intCast(roundDownToPowerOf2(full_chunks) * chunk_length);
+}
+
+fn parentOutput(parent_block: []const u8, key: [8]u32, flags: Flags) Output {
+    var block: [Blake3.block_length]u8 = undefined;
+    @memcpy(&block, parent_block[0..Blake3.block_length]);
+    return Output{
+        .input_cv = key,
+        .block = block,
+        .block_len = Blake3.block_length,
+        .counter = 0,
+        .flags = flags.with(.{ .parent = true }),
+    };
+}
+
+fn parentOutputFromCvs(left_cv: [8]u32, right_cv: [8]u32, key: [8]u32, flags: Flags) Output {
+    var block: [Blake3.block_length]u8 align(16) = undefined;
+    for (0..8) |i| {
+        store32(block[i * 4 ..][0..4], left_cv[i]);
+        store32(block[(i + 8) * 4 ..][0..4], right_cv[i]);
+    }
+    return Output{
+        .input_cv = key,
+        .block = block,
+        .block_len = Blake3.block_length,
+        .counter = 0,
+        .flags = flags.with(.{ .parent = true }),
+    };
+}
 
 const ChunkState = struct {
-    chaining_value: [8]u32 align(16),
+    cv: [8]u32 align(16),
     chunk_counter: u64,
-    block: [BLOCK_LEN]u8 align(16) = [_]u8{0} ** BLOCK_LEN,
-    block_len: u8 = 0,
-    blocks_compressed: u8 = 0,
-    flags: u8,
+    buf: [Blake3.block_length]u8 align(16),
+    buf_len: u8,
+    blocks_compressed: u8,
+    flags: Flags,
 
-    fn init(key: [8]u32, chunk_counter: u64, flags: u8) ChunkState {
+    fn init(key: [8]u32, flags: Flags) ChunkState {
         return ChunkState{
-            .chaining_value = key,
-            .chunk_counter = chunk_counter,
+            .cv = key,
+            .chunk_counter = 0,
+            .buf = [_]u8{0} ** Blake3.block_length,
+            .buf_len = 0,
+            .blocks_compressed = 0,
             .flags = flags,
         };
     }
 
+    fn reset(self: *ChunkState, key: [8]u32, chunk_counter: u64) void {
+        self.cv = key;
+        self.chunk_counter = chunk_counter;
+        self.blocks_compressed = 0;
+        self.buf = [_]u8{0} ** Blake3.block_length;
+        self.buf_len = 0;
+    }
+
     fn len(self: *const ChunkState) usize {
-        return BLOCK_LEN * @as(usize, self.blocks_compressed) + @as(usize, self.block_len);
-    }
-
-    fn fillBlockBuf(self: *ChunkState, input: []const u8) []const u8 {
-        const want = BLOCK_LEN - self.block_len;
-        const take = @min(want, input.len);
-        @memcpy(self.block[self.block_len..][0..take], input[0..take]);
-        self.block_len += @as(u8, @truncate(take));
-        return input[take..];
-    }
-
-    fn startFlag(self: *const ChunkState) u8 {
-        return if (self.blocks_compressed == 0) CHUNK_START else 0;
-    }
-
-    fn update(self: *ChunkState, input_slice: []const u8) void {
-        var input = input_slice;
-        while (input.len > 0) {
-            // If the block buffer is full, compress it and clear it. More
-            // input is coming, so this compression is not CHUNK_END.
-            if (self.block_len == BLOCK_LEN) {
-                const block_words = wordsFromLittleEndianBytes(16, self.block);
-                self.chaining_value = first8Words(compress(
-                    self.chaining_value,
-                    block_words,
-                    BLOCK_LEN,
-                    self.chunk_counter,
-                    self.flags | self.startFlag(),
-                ));
+        return (Blake3.block_length * @as(usize, self.blocks_compressed)) + @as(usize, self.buf_len);
+    }
+
+    fn fillBuf(self: *ChunkState, input: []const u8) usize {
+        const take = @min(Blake3.block_length - @as(usize, self.buf_len), input.len);
+        @memcpy(self.buf[self.buf_len..][0..take], input[0..take]);
+        self.buf_len += @intCast(take);
+        return take;
+    }
+
+    fn maybeStartFlag(self: *const ChunkState) Flags {
+        return if (self.blocks_compressed == 0) .{ .chunk_start = true } else .{};
+    }
+
+    fn update(self: *ChunkState, input: []const u8) void {
+        var inp = input;
+
+        while (inp.len > 0) {
+            if (self.buf_len == Blake3.block_length) {
+                compressInPlace(&self.cv, &self.buf, Blake3.block_length, self.chunk_counter, self.flags.with(self.maybeStartFlag()));
                 self.blocks_compressed += 1;
-                self.block = [_]u8{0} ** BLOCK_LEN;
-                self.block_len = 0;
+                self.buf = [_]u8{0} ** Blake3.block_length;
+                self.buf_len = 0;
             }
 
-            // Copy input bytes into the block buffer.
-            input = self.fillBlockBuf(input);
+            const take = self.fillBuf(inp);
+            inp = inp[take..];
         }
     }
 
     fn output(self: *const ChunkState) Output {
-        const block_words = wordsFromLittleEndianBytes(16, self.block);
+        const block_flags = self.flags.with(self.maybeStartFlag()).with(.{ .chunk_end = true });
         return Output{
-            .input_chaining_value = self.chaining_value,
-            .block_words = block_words,
-            .block_len = self.block_len,
+            .input_cv = self.cv,
+            .block = self.buf,
+            .block_len = self.buf_len,
             .counter = self.chunk_counter,
-            .flags = self.flags | self.startFlag() | CHUNK_END,
+            .flags = block_flags,
         };
     }
 };
 
-fn parentOutput(
-    left_child_cv: [8]u32,
-    right_child_cv: [8]u32,
-    key: [8]u32,
-    flags: u8,
-) Output {
-    var block_words: [16]u32 align(16) = undefined;
-    block_words[0..8].* = left_child_cv;
-    block_words[8..].* = right_child_cv;
-    return Output{
-        .input_chaining_value = key,
-        .block_words = block_words,
-        .block_len = BLOCK_LEN, // Always BLOCK_LEN (64) for parent nodes.
-        .counter = 0, // Always 0 for parent nodes.
-        .flags = PARENT | flags,
-    };
-}
+const Output = struct {
+    input_cv: [8]u32 align(16),
+    block: [Blake3.block_length]u8 align(16),
+    block_len: u8,
+    counter: u64,
+    flags: Flags,
 
-fn parentCv(
-    left_child_cv: [8]u32,
-    right_child_cv: [8]u32,
-    key: [8]u32,
-    flags: u8,
-) [8]u32 {
-    return parentOutput(left_child_cv, right_child_cv, key, flags).chainingValue();
-}
+    fn chainingValue(self: *const Output) [8]u32 {
+        var cv_words = self.input_cv;
+        compressInPlace(&cv_words, &self.block, self.block_len, self.counter, self.flags);
+        return cv_words;
+    }
+
+    fn rootBytes(self: *const Output, seek: u64, out: []u8) void {
+        if (out.len == 0) return;
+
+        var output_block_counter = seek / 64;
+        const offset_within_block = @as(usize, @intCast(seek % 64));
+        var out_remaining = out;
+
+        if (offset_within_block > 0) {
+            var wide_buf: [64]u8 = undefined;
+            compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), &wide_buf);
+            const available_bytes = 64 - offset_within_block;
+            const bytes = @min(out_remaining.len, available_bytes);
+            @memcpy(out_remaining[0..bytes], wide_buf[offset_within_block..][0..bytes]);
+            out_remaining = out_remaining[bytes..];
+            output_block_counter += 1;
+        }
 
-/// An incremental hasher that can accept any number of writes.
+        while (out_remaining.len >= 64) {
+            compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), out_remaining[0..64]);
+            out_remaining = out_remaining[64..];
+            output_block_counter += 1;
+        }
+
+        if (out_remaining.len > 0) {
+            var wide_buf: [64]u8 = undefined;
+            compressXof(&self.input_cv, &self.block, self.block_len, output_block_counter, self.flags.with(.{ .root = true }), &wide_buf);
+            @memcpy(out_remaining, wide_buf[0..out_remaining.len]);
+        }
+    }
+};
+
+/// BLAKE3 is a cryptographic hash function that produces a 256-bit digest by default but also supports extendable output.
 pub const Blake3 = struct {
+    pub const block_length = 64;
+    pub const digest_length = 32;
+    pub const key_length = 32;
+
     pub const Options = struct { key: ?[digest_length]u8 = null };
     pub const KdfOptions = struct {};
 
-    chunk_state: ChunkState,
     key: [8]u32,
-    cv_stack: [54][8]u32 = undefined, // Space for 54 subtree chaining values:
-    cv_stack_len: u8 = 0, // 2^54 * CHUNK_LEN = 2^64
-    flags: u8,
-
-    pub const block_length = BLOCK_LEN;
-    pub const digest_length = OUT_LEN;
-    pub const key_length = KEY_LEN;
-
-    fn init_internal(key: [8]u32, flags: u8) Blake3 {
-        return Blake3{
-            .chunk_state = ChunkState.init(key, 0, flags),
-            .key = key,
-            .flags = flags,
-        };
-    }
+    chunk: ChunkState,
+    cv_stack_len: u8,
+    cv_stack: [max_depth + 1][8]u32,
 
     /// Construct a new `Blake3` for the hash function, with an optional key
     pub fn init(options: Options) Blake3 {
         if (options.key) |key| {
-            const key_words = wordsFromLittleEndianBytes(8, key);
-            return Blake3.init_internal(key_words, KEYED_HASH);
+            const key_words = loadKeyWords(key);
+            return init_internal(key_words, .{ .keyed_hash = true });
         } else {
-            return Blake3.init_internal(IV, 0);
+            return init_internal(iv, .{});
         }
     }
 
@@ -393,12 +835,12 @@ pub const Blake3 = struct {
     /// string should be hardcoded, globally unique, and application-specific.
     pub fn initKdf(context: []const u8, options: KdfOptions) Blake3 {
         _ = options;
-        var context_hasher = Blake3.init_internal(IV, DERIVE_KEY_CONTEXT);
+        var context_hasher = init_internal(iv, .{ .derive_key_context = true });
         context_hasher.update(context);
-        var context_key: [KEY_LEN]u8 = undefined;
-        context_hasher.final(context_key[0..]);
-        const context_key_words = wordsFromLittleEndianBytes(8, context_key);
-        return Blake3.init_internal(context_key_words, DERIVE_KEY_MATERIAL);
+        var context_key: [key_length]u8 = undefined;
+        context_hasher.final(&context_key);
+        const context_key_words = loadKeyWords(context_key);
+        return init_internal(context_key_words, .{ .derive_key_material = true });
     }
 
     pub fn hash(b: []const u8, out: []u8, options: Options) void {
@@ -407,78 +849,135 @@ pub const Blake3 = struct {
         d.final(out);
     }
 
-    fn pushCv(self: *Blake3, cv: [8]u32) void {
-        self.cv_stack[self.cv_stack_len] = cv;
-        self.cv_stack_len += 1;
+    fn init_internal(key: [8]u32, flags: Flags) Blake3 {
+        return Blake3{
+            .key = key,
+            .chunk = ChunkState.init(key, flags),
+            .cv_stack_len = 0,
+            .cv_stack = undefined,
+        };
     }
 
-    fn popCv(self: *Blake3) [8]u32 {
-        self.cv_stack_len -= 1;
-        return self.cv_stack[self.cv_stack_len];
-    }
-
-    // Section 5.1.2 of the BLAKE3 spec explains this algorithm in more detail.
-    fn addChunkChainingValue(self: *Blake3, first_cv: [8]u32, total_chunks: u64) void {
-        // This chunk might complete some subtrees. For each completed subtree,
-        // its left child will be the current top entry in the CV stack, and
-        // its right child will be the current value of `new_cv`. Pop each left
-        // child off the stack, merge it with `new_cv`, and overwrite `new_cv`
-        // with the result. After all these merges, push the final value of
-        // `new_cv` onto the stack. The number of completed subtrees is given
-        // by the number of trailing 0-bits in the new total number of chunks.
-        var new_cv = first_cv;
-        var chunk_counter = total_chunks;
-        while (chunk_counter & 1 == 0) {
-            new_cv = parentCv(self.popCv(), new_cv, self.key, self.flags);
-            chunk_counter >>= 1;
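+    // Keep one chaining value per completed subtree: merge until the stack length equals the
+    // popcount of the chunk count (see section 5.1.2 of the BLAKE3 spec).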
+    fn mergeCvStack(self: *Blake3, total_len: u64) void {
+        const post_merge_stack_len = @as(u8, @intCast(@popCount(total_len)));
+        while (self.cv_stack_len > post_merge_stack_len) {
+            const left_cv = self.cv_stack[self.cv_stack_len - 2];
+            const right_cv = self.cv_stack[self.cv_stack_len - 1];
+            const output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags);
+            const cv = output.chainingValue();
+            self.cv_stack[self.cv_stack_len - 2] = cv;
+            self.cv_stack_len -= 1;
         }
-        self.pushCv(new_cv);
+    }
+
+    fn pushCv(self: *Blake3, new_cv: [8]u32, chunk_counter: u64) void {
+        self.mergeCvStack(chunk_counter);
+        self.cv_stack[self.cv_stack_len] = new_cv;
+        self.cv_stack_len += 1;
     }
 
     /// Add input to the hash state. This can be called any number of times.
-    pub fn update(self: *Blake3, input_slice: []const u8) void {
-        var input = input_slice;
-        while (input.len > 0) {
-            // If the current chunk is complete, finalize it and reset the
-            // chunk state. More input is coming, so this chunk is not ROOT.
-            if (self.chunk_state.len() == CHUNK_LEN) {
-                const chunk_cv = self.chunk_state.output().chainingValue();
-                const total_chunks = self.chunk_state.chunk_counter + 1;
-                self.addChunkChainingValue(chunk_cv, total_chunks);
-                self.chunk_state = ChunkState.init(self.key, total_chunks, self.flags);
+    pub fn update(self: *Blake3, input: []const u8) void {
+        if (input.len == 0) return;
+
+        var inp = input;
+
+        if (self.chunk.len() > 0) {
+            const take = @min(chunk_length - self.chunk.len(), inp.len);
+            self.chunk.update(inp[0..take]);
+            inp = inp[take..];
+            if (inp.len > 0) {
+                const output = self.chunk.output();
+                const chunk_cv = output.chainingValue();
+                self.pushCv(chunk_cv, self.chunk.chunk_counter);
+                self.chunk.reset(self.key, self.chunk.chunk_counter + 1);
+            } else {
+                return;
             }
+        }
+
+        while (inp.len > chunk_length) {
+            var subtree_len = roundDownToPowerOf2(inp.len);
+            const count_so_far = self.chunk.chunk_counter * chunk_length;
 
-            // Compress input bytes into the current chunk state.
-            const want = CHUNK_LEN - self.chunk_state.len();
-            const take = @min(want, input.len);
-            self.chunk_state.update(input[0..take]);
-            input = input[take..];
+            while ((subtree_len - 1) & count_so_far != 0) {
+                subtree_len /= 2;
+            }
+
+            const subtree_chunks = subtree_len / chunk_length;
+            if (subtree_len <= chunk_length) {
+                var chunk_state = ChunkState.init(self.key, self.chunk.flags);
+                chunk_state.chunk_counter = self.chunk.chunk_counter;
+                chunk_state.update(inp[0..@intCast(subtree_len)]);
+                const output = chunk_state.output();
+                const cv = output.chainingValue();
+                self.pushCv(cv, chunk_state.chunk_counter);
+            } else {
+                var cv_pair: [2 * digest_length]u8 = undefined;
+                compressSubtreeToParentNode(inp[0..@intCast(subtree_len)], self.key, self.chunk.chunk_counter, self.chunk.flags, &cv_pair);
+                const left_cv = loadCvWords(cv_pair[0..digest_length].*);
+                const right_cv = loadCvWords(cv_pair[digest_length..][0..digest_length].*);
+                self.pushCv(left_cv, self.chunk.chunk_counter);
+                self.pushCv(right_cv, self.chunk.chunk_counter + (subtree_chunks / 2));
+            }
+            self.chunk.chunk_counter += subtree_chunks;
+            inp = inp[@intCast(subtree_len)..];
+        }
+
+        if (inp.len > 0) {
+            self.chunk.update(inp);
+            self.mergeCvStack(self.chunk.chunk_counter);
         }
     }
 
     /// Finalize the hash and write any number of output bytes.
-    pub fn final(self: *const Blake3, out_slice: []u8) void {
-        // Starting with the Output from the current chunk, compute all the
-        // parent chaining values along the right edge of the tree, until we
-        // have the root Output.
-        var output = self.chunk_state.output();
-        var parent_nodes_remaining: usize = self.cv_stack_len;
-        while (parent_nodes_remaining > 0) {
-            parent_nodes_remaining -= 1;
-            output = parentOutput(
-                self.cv_stack[parent_nodes_remaining],
-                output.chainingValue(),
-                self.key,
-                self.flags,
-            );
+    pub fn final(self: *const Blake3, out: []u8) void {
+        self.finalizeSeek(0, out);
+    }
+
+    /// Finalize the hash and write any number of output bytes, starting at a given seek position.
+    /// This is an XOF (extendable-output function) extension.
+    pub fn finalizeSeek(self: *const Blake3, seek: u64, out: []u8) void {
+        if (out.len == 0) return;
+
+        if (self.cv_stack_len == 0) {
+            const output = self.chunk.output();
+            output.rootBytes(seek, out);
+            return;
         }
-        output.rootOutputBytes(out_slice);
+
+        var output: Output = undefined;
+        var cvs_remaining: usize = undefined;
+
+        if (self.chunk.len() > 0) {
+            cvs_remaining = self.cv_stack_len;
+            output = self.chunk.output();
+        } else {
+            cvs_remaining = self.cv_stack_len - 2;
+            const left_cv = self.cv_stack[cvs_remaining];
+            const right_cv = self.cv_stack[cvs_remaining + 1];
+            output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags);
+        }
+
+        while (cvs_remaining > 0) {
+            cvs_remaining -= 1;
+            const left_cv = self.cv_stack[cvs_remaining];
+            const right_cv = output.chainingValue();
+            output = parentOutputFromCvs(left_cv, right_cv, self.key, self.chunk.flags);
+        }
+
+        output.rootBytes(seek, out);
+    }
+
+    pub fn reset(self: *Blake3) void {
+        self.chunk.reset(self.key, 0);
+        self.cv_stack_len = 0;
     }
 };
 
 // Use named type declarations to workaround crash with anonymous structs (issue #4373).
 const ReferenceTest = struct {
-    key: *const [KEY_LEN]u8,
+    key: *const [Blake3.key_length]u8,
     context_string: []const u8,
     cases: []const ReferenceTestCase,
 };
@@ -663,7 +1162,7 @@ fn testBlake3(hasher: *Blake3, input_len: usize, expected_hex: [262]u8) !void {
     // Compare to expected value
     var expected_bytes: [expected_hex.len / 2]u8 = undefined;
     _ = fmt.hexToBytes(expected_bytes[0..], expected_hex[0..]) catch unreachable;
-    try testing.expectEqual(actual_bytes, expected_bytes);
+    try std.testing.expectEqual(actual_bytes, expected_bytes);
 
     // Restore initial state
     hasher.* = initial_state;