Commit a72292513e

Andrew Kelley <andrew@ziglang.org>
2024-05-04 02:57:48
add std.Thread.Pool.spawnWg
This function accepts a WaitGroup parameter and manages the reference counting therein. It also is infallible. The existing `spawn` function is still handy when the job wants to further schedule more tasks.
1 parent a96b78c
Changed files (6)
lib
compiler
std
Thread
src
lib/compiler/build_runner.zig
@@ -466,10 +466,9 @@ fn runStepNames(
             const step = steps_slice[steps_slice.len - i - 1];
             if (step.state == .skipped_oom) continue;
 
-            wait_group.start();
-            thread_pool.spawn(workerMakeOneStep, .{
+            thread_pool.spawnWg(&wait_group, workerMakeOneStep, .{
                 &wait_group, &thread_pool, b, step, &step_prog, run,
-            }) catch @panic("OOM");
+            });
         }
     }
     assert(run.memory_blocked_steps.items.len == 0);
@@ -895,8 +894,6 @@ fn workerMakeOneStep(
     prog_node: *std.Progress.Node,
     run: *Run,
 ) void {
-    defer wg.finish();
-
     // First, check the conditions for running this step. If they are not met,
     // then we return without doing the step, relying on another worker to
     // queue this step up again when dependencies are met.
@@ -976,10 +973,9 @@ fn workerMakeOneStep(
 
         // Successful completion of a step, so we queue up its dependants as well.
         for (s.dependants.items) |dep| {
-            wg.start();
-            thread_pool.spawn(workerMakeOneStep, .{
+            thread_pool.spawnWg(wg, workerMakeOneStep, .{
                 wg, thread_pool, b, dep, prog_node, run,
-            }) catch @panic("OOM");
+            });
         }
     }
 
@@ -1002,10 +998,9 @@ fn workerMakeOneStep(
             if (dep.max_rss <= remaining) {
                 remaining -= dep.max_rss;
 
-                wg.start();
-                thread_pool.spawn(workerMakeOneStep, .{
+                thread_pool.spawnWg(wg, workerMakeOneStep, .{
                     wg, thread_pool, b, dep, prog_node, run,
-                }) catch @panic("OOM");
+                });
             } else {
                 run.memory_blocked_steps.items[i] = dep;
                 i += 1;
lib/std/Thread/Pool.zig
@@ -75,6 +75,65 @@ fn join(pool: *Pool, spawned: usize) void {
     pool.allocator.free(pool.threads);
 }
 
+/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
+/// `WaitGroup.finish` after it returns.
+///
+/// In the case that queuing the function call fails to allocate memory, or the
+/// target is single-threaded, the function is called directly.
+pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args: anytype) void {
+    wait_group.start();
+
+    if (builtin.single_threaded) {
+        @call(.auto, func, args);
+        wait_group.finish();
+        return;
+    }
+
+    const Args = @TypeOf(args);
+    const Closure = struct {
+        arguments: Args,
+        pool: *Pool,
+        run_node: RunQueue.Node = .{ .data = .{ .runFn = runFn } },
+        wait_group: *WaitGroup,
+
+        fn runFn(runnable: *Runnable) void {
+            const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable);
+            const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node));
+            @call(.auto, func, closure.arguments);
+            closure.wait_group.finish();
+
+            // The thread pool's allocator is protected by the mutex.
+            const mutex = &closure.pool.mutex;
+            mutex.lock();
+            defer mutex.unlock();
+
+            closure.pool.allocator.destroy(closure);
+        }
+    };
+
+    {
+        pool.mutex.lock();
+
+        const closure = pool.allocator.create(Closure) catch {
+            pool.mutex.unlock();
+            @call(.auto, func, args);
+            wait_group.finish();
+            return;
+        };
+        closure.* = .{
+            .arguments = args,
+            .pool = pool,
+            .wait_group = wait_group,
+        };
+
+        pool.run_queue.prepend(&closure.run_node);
+        pool.mutex.unlock();
+    }
+
+    // Notify waiting threads outside the lock to try and keep the critical section small.
+    pool.cond.signal();
+}
+
 pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
     if (builtin.single_threaded) {
         @call(.auto, func, args);
src/link/MachO/hasher.zig
@@ -36,14 +36,12 @@ pub fn ParallelHasher(comptime Hasher: type) type {
                         file_size - fstart
                     else
                         chunk_size;
-                    wg.start();
-                    try self.thread_pool.spawn(worker, .{
+                    self.thread_pool.spawnWg(&wg, worker, .{
                         file,
                         fstart,
                         buffer[fstart..][0..fsize],
                         &(out_buf.*),
                         &(result.*),
-                        &wg,
                     });
                 }
             }
@@ -56,9 +54,7 @@ pub fn ParallelHasher(comptime Hasher: type) type {
             buffer: []u8,
             out: *[hash_size]u8,
             err: *fs.File.PReadError!usize,
-            wg: *WaitGroup,
         ) void {
-            defer wg.finish();
             err.* = file.preadAll(buffer, fstart);
             Hasher.hash(buffer, out, .{});
         }
src/Package/Fetch.zig
@@ -722,14 +722,7 @@ fn queueJobsForDeps(f: *Fetch) RunError!void {
     const thread_pool = f.job_queue.thread_pool;
 
     for (new_fetches, prog_names) |*new_fetch, prog_name| {
-        f.job_queue.wait_group.start();
-        thread_pool.spawn(workerRun, .{ new_fetch, prog_name }) catch |err| switch (err) {
-            error.OutOfMemory => {
-                new_fetch.oom_flag = true;
-                f.job_queue.wait_group.finish();
-                continue;
-            },
-        };
+        thread_pool.spawnWg(&f.job_queue.wait_group, workerRun, .{ new_fetch, prog_name });
     }
 }
 
@@ -750,8 +743,6 @@ pub fn relativePathDigest(
 }
 
 pub fn workerRun(f: *Fetch, prog_name: []const u8) void {
-    defer f.job_queue.wait_group.finish();
-
     var prog_node = f.prog_node.start(prog_name, 0);
     defer prog_node.end();
     prog_node.activate();
@@ -1477,10 +1468,7 @@ fn computeHash(
                     .fs_path = fs_path,
                     .failure = undefined, // to be populated by the worker
                 };
-                wait_group.start();
-                try thread_pool.spawn(workerDeleteFile, .{
-                    root_dir, deleted_file, &wait_group,
-                });
+                thread_pool.spawnWg(&wait_group, workerDeleteFile, .{ root_dir, deleted_file });
                 try deleted_files.append(deleted_file);
                 continue;
             }
@@ -1507,10 +1495,7 @@ fn computeHash(
                 .hash = undefined, // to be populated by the worker
                 .failure = undefined, // to be populated by the worker
             };
-            wait_group.start();
-            try thread_pool.spawn(workerHashFile, .{
-                root_dir, hashed_file, &wait_group,
-            });
+            thread_pool.spawnWg(&wait_group, workerHashFile, .{ root_dir, hashed_file });
             try all_files.append(hashed_file);
         }
     }
@@ -1602,13 +1587,11 @@ fn dumpHashInfo(all_files: []const *const HashedFile) !void {
     try bw.flush();
 }
 
-fn workerHashFile(dir: fs.Dir, hashed_file: *HashedFile, wg: *WaitGroup) void {
-    defer wg.finish();
+fn workerHashFile(dir: fs.Dir, hashed_file: *HashedFile) void {
     hashed_file.failure = hashFileFallible(dir, hashed_file);
 }
 
-fn workerDeleteFile(dir: fs.Dir, deleted_file: *DeletedFile, wg: *WaitGroup) void {
-    defer wg.finish();
+fn workerDeleteFile(dir: fs.Dir, deleted_file: *DeletedFile) void {
     deleted_file.failure = deleteFileFallible(dir, deleted_file);
 }
 
src/Compilation.zig
@@ -3273,7 +3273,7 @@ pub fn performAllTheWork(
 
     if (!build_options.only_c and !build_options.only_core_functionality) {
         if (comp.docs_emit != null) {
-            try taskDocsCopy(comp, &comp.work_queue_wait_group);
+            comp.thread_pool.spawnWg(&comp.work_queue_wait_group, workerDocsCopy, .{comp});
             comp.work_queue_wait_group.spawnManager(workerDocsWasm, .{ comp, &wasm_prog_node });
         }
     }
@@ -3305,39 +3305,34 @@ pub fn performAllTheWork(
 
                 const file = mod.builtin_file orelse continue;
 
-                comp.astgen_wait_group.start();
-                try comp.thread_pool.spawn(workerUpdateBuiltinZigFile, .{
-                    comp, mod, file, &comp.astgen_wait_group,
+                comp.thread_pool.spawnWg(&comp.astgen_wait_group, workerUpdateBuiltinZigFile, .{
+                    comp, mod, file,
                 });
             }
         }
 
         while (comp.astgen_work_queue.readItem()) |file| {
-            comp.astgen_wait_group.start();
-            try comp.thread_pool.spawn(workerAstGenFile, .{
+            comp.thread_pool.spawnWg(&comp.astgen_wait_group, workerAstGenFile, .{
                 comp, file, &zir_prog_node, &comp.astgen_wait_group, .root,
             });
         }
 
         while (comp.embed_file_work_queue.readItem()) |embed_file| {
-            comp.astgen_wait_group.start();
-            try comp.thread_pool.spawn(workerCheckEmbedFile, .{
-                comp, embed_file, &comp.astgen_wait_group,
+            comp.thread_pool.spawnWg(&comp.astgen_wait_group, workerCheckEmbedFile, .{
+                comp, embed_file,
             });
         }
 
         while (comp.c_object_work_queue.readItem()) |c_object| {
-            comp.work_queue_wait_group.start();
-            try comp.thread_pool.spawn(workerUpdateCObject, .{
-                comp, c_object, &c_obj_prog_node, &comp.work_queue_wait_group,
+            comp.thread_pool.spawnWg(&comp.work_queue_wait_group, workerUpdateCObject, .{
+                comp, c_object, &c_obj_prog_node,
             });
         }
 
         if (!build_options.only_core_functionality) {
             while (comp.win32_resource_work_queue.readItem()) |win32_resource| {
-                comp.work_queue_wait_group.start();
-                try comp.thread_pool.spawn(workerUpdateWin32Resource, .{
-                    comp, win32_resource, &win32_resource_prog_node, &comp.work_queue_wait_group,
+                comp.thread_pool.spawnWg(&comp.work_queue_wait_group, workerUpdateWin32Resource, .{
+                    comp, win32_resource, &win32_resource_prog_node,
                 });
             }
         }
@@ -3680,14 +3675,7 @@ fn processOneJob(comp: *Compilation, job: Job, prog_node: *std.Progress.Node) !v
     }
 }
 
-fn taskDocsCopy(comp: *Compilation, wg: *WaitGroup) !void {
-    wg.start();
-    errdefer wg.finish();
-    try comp.thread_pool.spawn(workerDocsCopy, .{ comp, wg });
-}
-
-fn workerDocsCopy(comp: *Compilation, wg: *WaitGroup) void {
-    defer wg.finish();
+fn workerDocsCopy(comp: *Compilation) void {
     docsCopyFallible(comp) catch |err| {
         return comp.lockAndSetMiscFailure(
             .docs_copy,
@@ -3965,8 +3953,6 @@ fn workerAstGenFile(
     wg: *WaitGroup,
     src: AstGenSrc,
 ) void {
-    defer wg.finish();
-
     var child_prog_node = prog_node.start(file.sub_file_path, 0);
     child_prog_node.activate();
     defer child_prog_node.end();
@@ -4025,13 +4011,9 @@ fn workerAstGenFile(
                     .importing_file = file,
                     .import_tok = item.data.token,
                 } };
-                wg.start();
-                comp.thread_pool.spawn(workerAstGenFile, .{
+                comp.thread_pool.spawnWg(wg, workerAstGenFile, .{
                     comp, import_result.file, prog_node, wg, sub_src,
-                }) catch {
-                    wg.finish();
-                    continue;
-                };
+                });
             }
         }
     }
@@ -4041,9 +4023,7 @@ fn workerUpdateBuiltinZigFile(
     comp: *Compilation,
     mod: *Package.Module,
     file: *Module.File,
-    wg: *WaitGroup,
 ) void {
-    defer wg.finish();
     Builtin.populateFile(comp, mod, file) catch |err| {
         comp.mutex.lock();
         defer comp.mutex.unlock();
@@ -4054,13 +4034,7 @@ fn workerUpdateBuiltinZigFile(
     };
 }
 
-fn workerCheckEmbedFile(
-    comp: *Compilation,
-    embed_file: *Module.EmbedFile,
-    wg: *WaitGroup,
-) void {
-    defer wg.finish();
-
+fn workerCheckEmbedFile(comp: *Compilation, embed_file: *Module.EmbedFile) void {
     comp.detectEmbedFileUpdate(embed_file) catch |err| {
         comp.reportRetryableEmbedFileError(embed_file, err) catch |oom| switch (oom) {
             // Swallowing this error is OK because it's implied to be OOM when
@@ -4289,10 +4263,7 @@ fn workerUpdateCObject(
     comp: *Compilation,
     c_object: *CObject,
     progress_node: *std.Progress.Node,
-    wg: *WaitGroup,
 ) void {
-    defer wg.finish();
-
     comp.updateCObject(c_object, progress_node) catch |err| switch (err) {
         error.AnalysisFail => return,
         else => {
@@ -4309,10 +4280,7 @@ fn workerUpdateWin32Resource(
     comp: *Compilation,
     win32_resource: *Win32Resource,
     progress_node: *std.Progress.Node,
-    wg: *WaitGroup,
 ) void {
-    defer wg.finish();
-
     comp.updateWin32Resource(win32_resource, progress_node) catch |err| switch (err) {
         error.AnalysisFail => return,
         else => {
src/main.zig
@@ -5109,8 +5109,9 @@ fn cmdBuild(gpa: Allocator, arena: Allocator, args: []const []const u8) !void {
                     &fetch,
                 );
 
-                job_queue.wait_group.start();
-                try job_queue.thread_pool.spawn(Package.Fetch.workerRun, .{ &fetch, "root" });
+                job_queue.thread_pool.spawnWg(&job_queue.wait_group, Package.Fetch.workerRun, .{
+                    &fetch, "root",
+                });
                 job_queue.wait_group.wait();
 
                 try job_queue.consolidateErrors();