Commit d5585bc650

Frank Denis <124872+jedisct1@users.noreply.github.com>
2025-11-01 07:40:03
Implement threaded BLAKE3 (#25587)
Allows BLAKE3 to be computed using multiple threads.
1 parent 5a38dd2
Changed files (2)
lib/std/crypto/benchmark.zig
@@ -35,6 +35,10 @@ const hashes = [_]Crypto{
     Crypto{ .ty = crypto.hash.Blake3, .name = "blake3" },
 };
 
+const parallel_hashes = [_]Crypto{
+    Crypto{ .ty = crypto.hash.Blake3, .name = "blake3-parallel" },
+};
+
 const block_size: usize = 8 * 8192;
 
 pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64 {
@@ -61,6 +65,25 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
     return throughput;
 }
 
+pub fn benchmarkHashParallel(comptime Hash: anytype, comptime bytes: comptime_int, allocator: mem.Allocator, io: std.Io) !u64 {
+    const data: []u8 = try allocator.alloc(u8, bytes);
+    defer allocator.free(data);
+    random.bytes(data);
+
+    var timer = try Timer.start();
+    const start = timer.lap();
+    var final: [Hash.digest_length]u8 = undefined;
+    try Hash.hashParallel(data, &final, .{}, allocator, io);
+    std.mem.doNotOptimizeAway(final);
+
+    const end = timer.read();
+
+    const elapsed_s = @as(f64, @floatFromInt(end - start)) / time.ns_per_s;
+    const throughput = @as(u64, @intFromFloat(bytes / elapsed_s));
+
+    return throughput;
+}
+
 const macs = [_]Crypto{
     Crypto{ .ty = crypto.onetimeauth.Ghash, .name = "ghash" },
     Crypto{ .ty = crypto.onetimeauth.Polyval, .name = "polyval" },
@@ -512,6 +535,18 @@ pub fn main() !void {
         }
     }
 
+    var io_threaded = std.Io.Threaded.init(arena_allocator);
+    defer io_threaded.deinit();
+    const io = io_threaded.io();
+
+    inline for (parallel_hashes) |H| {
+        if (filter == null or std.mem.indexOf(u8, H.name, filter.?) != null) {
+            const throughput = try benchmarkHashParallel(H.ty, mode(128 * MiB), arena_allocator, io);
+            try stdout.print("{s:>17}: {:10} MiB/s\n", .{ H.name, throughput / (1 * MiB) });
+            try stdout.flush();
+        }
+    }
+
     inline for (macs) |M| {
         if (filter == null or std.mem.indexOf(u8, M.name, filter.?) != null) {
             const throughput = try benchmarkMac(M.ty, mode(128 * MiB));
lib/std/crypto/blake3.zig
@@ -2,6 +2,8 @@ const std = @import("std");
 const builtin = @import("builtin");
 const fmt = std.fmt;
 const mem = std.mem;
+const Io = std.Io;
+const Thread = std.Thread;
 
 const Vec4 = @Vector(4, u32);
 const Vec8 = @Vector(8, u32);
@@ -14,6 +16,11 @@ pub const simd_degree = std.simd.suggestVectorLength(u32) orelse 1;
 pub const max_simd_degree = simd_degree;
 const max_simd_degree_or_2 = if (max_simd_degree > 2) max_simd_degree else 2;
 
+/// Threshold for switching to parallel processing.
+/// Below this size, sequential hashing is used.
+/// Benchmarks generally show significant speedup starting at 3 MiB.
+const parallel_threshold = 3 * 1024 * 1024;
+
 const iv: [8]u32 = .{
     0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
     0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
@@ -666,6 +673,95 @@ fn leftSubtreeLen(input_len: usize) usize {
     return @intCast(roundDownToPowerOf2(full_chunks) * chunk_length);
 }
 
+const ChunkBatch = struct {
+    input: []const u8,
+    start_chunk: usize,
+    end_chunk: usize,
+    cvs: [][8]u32,
+    key: [8]u32,
+    flags: Flags,
+
+    fn process(ctx: ChunkBatch) void {
+        var cv_buffer: [max_simd_degree * Blake3.digest_length]u8 = undefined;
+        var chunk_idx = ctx.start_chunk;
+
+        while (chunk_idx < ctx.end_chunk) {
+            const remaining = ctx.end_chunk - chunk_idx;
+            const batch_size = @min(remaining, max_simd_degree);
+            const offset = chunk_idx * chunk_length;
+            const batch_len = @as(usize, batch_size) * chunk_length;
+
+            const num_cvs = compressChunksParallel(
+                ctx.input[offset..][0..batch_len],
+                ctx.key,
+                chunk_idx,
+                ctx.flags,
+                &cv_buffer,
+            );
+
+            for (0..num_cvs) |i| {
+                const cv_bytes = cv_buffer[i * Blake3.digest_length ..][0..Blake3.digest_length];
+                ctx.cvs[chunk_idx + i] = loadCvWords(cv_bytes.*);
+            }
+
+            chunk_idx += batch_size;
+        }
+    }
+};
+
+const ParentBatchContext = struct {
+    input_cvs: [][8]u32,
+    output_cvs: [][8]u32,
+    start_idx: usize,
+    end_idx: usize,
+    key: [8]u32,
+    flags: Flags,
+};
+
+fn processParentBatch(ctx: ParentBatchContext) void {
+    for (ctx.start_idx..ctx.end_idx) |i| {
+        const output = parentOutputFromCvs(ctx.input_cvs[i * 2], ctx.input_cvs[i * 2 + 1], ctx.key, ctx.flags);
+        ctx.output_cvs[i] = output.chainingValue();
+    }
+}
+
+fn buildMerkleTreeLayerParallel(
+    input_cvs: [][8]u32,
+    output_cvs: [][8]u32,
+    key: [8]u32,
+    flags: Flags,
+    io: Io,
+) void {
+    const num_parents = input_cvs.len / 2;
+
+    if (num_parents <= 16) {
+        for (0..num_parents) |i| {
+            const output = parentOutputFromCvs(input_cvs[i * 2], input_cvs[i * 2 + 1], key, flags);
+            output_cvs[i] = output.chainingValue();
+        }
+        return;
+    }
+
+    const num_workers = Thread.getCpuCount() catch 1;
+    const parents_per_worker = (num_parents + num_workers - 1) / num_workers;
+    var group: Io.Group = .init;
+
+    for (0..num_workers) |worker_id| {
+        const start_idx = worker_id * parents_per_worker;
+        if (start_idx >= num_parents) break;
+
+        group.async(io, processParentBatch, .{ParentBatchContext{
+            .input_cvs = input_cvs,
+            .output_cvs = output_cvs,
+            .start_idx = start_idx,
+            .end_idx = @min(start_idx + parents_per_worker, num_parents),
+            .key = key,
+            .flags = flags,
+        }});
+    }
+    group.wait(io);
+}
+
 fn parentOutput(parent_block: []const u8, key: [8]u32, flags: Flags) Output {
     var block: [Blake3.block_length]u8 = undefined;
     @memcpy(&block, parent_block[0..Blake3.block_length]);
@@ -705,7 +801,7 @@ const ChunkState = struct {
         return ChunkState{
             .cv = key,
             .chunk_counter = 0,
-            .buf = [_]u8{0} ** Blake3.block_length,
+            .buf = @splat(0),
             .buf_len = 0,
             .blocks_compressed = 0,
             .flags = flags,
@@ -716,7 +812,7 @@ const ChunkState = struct {
         self.cv = key;
         self.chunk_counter = chunk_counter;
         self.blocks_compressed = 0;
-        self.buf = [_]u8{0} ** Blake3.block_length;
+        self.buf = @splat(0);
         self.buf_len = 0;
     }
 
@@ -742,7 +838,7 @@ const ChunkState = struct {
             if (self.buf_len == Blake3.block_length) {
                 compressInPlace(&self.cv, &self.buf, Blake3.block_length, self.chunk_counter, self.flags.with(self.maybeStartFlag()));
                 self.blocks_compressed += 1;
-                self.buf = [_]u8{0} ** Blake3.block_length;
+                self.buf = @splat(0);
                 self.buf_len = 0;
             }
 
@@ -849,6 +945,90 @@ pub const Blake3 = struct {
         d.final(out);
     }
 
+    pub fn hashParallel(b: []const u8, out: []u8, options: Options, allocator: std.mem.Allocator, io: Io) !void {
+        if (b.len < parallel_threshold) {
+            return hash(b, out, options);
+        }
+
+        const key_words = if (options.key) |key| loadKeyWords(key) else iv;
+        const flags: Flags = if (options.key != null) .{ .keyed_hash = true } else .{};
+
+        const num_full_chunks = b.len / chunk_length;
+        const thread_count = Thread.getCpuCount() catch 1;
+        if (thread_count <= 1 or num_full_chunks == 0) {
+            return hash(b, out, options);
+        }
+
+        const cvs = try allocator.alloc([8]u32, num_full_chunks);
+        defer allocator.free(cvs);
+
+        // Process chunks in parallel
+        const num_workers = thread_count;
+        const chunks_per_worker = (num_full_chunks + num_workers - 1) / num_workers;
+        var group: Io.Group = .init;
+
+        for (0..num_workers) |worker_id| {
+            const start_chunk = worker_id * chunks_per_worker;
+            if (start_chunk >= num_full_chunks) break;
+
+            group.async(io, ChunkBatch.process, .{ChunkBatch{
+                .input = b,
+                .start_chunk = start_chunk,
+                .end_chunk = @min(start_chunk + chunks_per_worker, num_full_chunks),
+                .cvs = cvs,
+                .key = key_words,
+                .flags = flags,
+            }});
+        }
+        group.wait(io);
+
+        // Build Merkle tree in parallel layers using ping-pong buffers
+        const max_intermediate_size = (num_full_chunks + 1) / 2;
+        const buffer0 = try allocator.alloc([8]u32, max_intermediate_size);
+        defer allocator.free(buffer0);
+        const buffer1 = try allocator.alloc([8]u32, max_intermediate_size);
+        defer allocator.free(buffer1);
+
+        var current_level = cvs;
+        var next_level_buf = buffer0;
+        var toggle = false;
+
+        while (current_level.len > 8) {
+            const num_parents = current_level.len / 2;
+            const has_odd = current_level.len % 2 == 1;
+            const next_level_size = num_parents + @intFromBool(has_odd);
+
+            buildMerkleTreeLayerParallel(
+                current_level[0 .. num_parents * 2],
+                next_level_buf[0..num_parents],
+                key_words,
+                flags,
+                io,
+            );
+
+            if (has_odd) {
+                next_level_buf[num_parents] = current_level[current_level.len - 1];
+            }
+
+            current_level = next_level_buf[0..next_level_size];
+            next_level_buf = if (toggle) buffer0 else buffer1;
+            toggle = !toggle;
+        }
+
+        // Finalize remaining small tree sequentially
+        var hasher = init_internal(key_words, flags);
+        for (current_level, 0..) |cv, i| hasher.pushCv(cv, i);
+
+        hasher.chunk.chunk_counter = num_full_chunks;
+        const remaining_bytes = b.len % chunk_length;
+        if (remaining_bytes > 0) {
+            hasher.chunk.update(b[num_full_chunks * chunk_length ..]);
+            hasher.mergeCvStack(hasher.chunk.chunk_counter);
+        }
+
+        hasher.final(out);
+    }
+
     fn init_internal(key: [8]u32, flags: Flags) Blake3 {
         return Blake3{
             .key = key,
@@ -1182,3 +1362,48 @@ test "BLAKE3 reference test cases" {
         try testBlake3(derive_key, t.input_len, t.derive_key.*);
     }
 }
+
+test "BLAKE3 parallel vs sequential" {
+    const allocator = std.testing.allocator;
+    const io = std.testing.io;
+
+    // Test various sizes including those above the parallelization threshold
+    const test_sizes = [_]usize{
+        0, // Empty
+        64, // One block
+        1024, // One chunk
+        1024 * 10, // Multiple chunks
+        1024 * 100, // 100KB
+        1024 * 1000, // 1MB
+        1024 * 5000, // 5MB (above threshold)
+        1024 * 10000, // 10MB (above threshold)
+    };
+
+    for (test_sizes) |size| {
+        // Allocate and fill test data with a pattern
+        const input = try allocator.alloc(u8, size);
+        defer allocator.free(input);
+        for (input, 0..) |*byte, i| {
+            byte.* = @truncate(i);
+        }
+
+        // Test regular hash
+        var expected: [32]u8 = undefined;
+        Blake3.hash(input, &expected, .{});
+
+        var actual: [32]u8 = undefined;
+        try Blake3.hashParallel(input, &actual, .{}, allocator, io);
+
+        try std.testing.expectEqualSlices(u8, &expected, &actual);
+
+        // Test keyed hash
+        const key: [32]u8 = @splat(0x42);
+        var expected_keyed: [32]u8 = undefined;
+        Blake3.hash(input, &expected_keyed, .{ .key = key });
+
+        var actual_keyed: [32]u8 = undefined;
+        try Blake3.hashParallel(input, &actual_keyed, .{ .key = key }, allocator, io);
+
+        try std.testing.expectEqualSlices(u8, &expected_keyed, &actual_keyed);
+    }
+}