Commit e7caf3a54c
Changed files (3)
lib
std
lib/std/Io/EventLoop.zig
@@ -7,12 +7,13 @@ const EventLoop = @This();
const Alignment = std.mem.Alignment;
const IoUring = std.os.linux.IoUring;
+/// Must be a thread-safe allocator.
gpa: Allocator,
mutex: std.Thread.Mutex,
-queue: std.DoublyLinkedList(void),
+queue: std.DoublyLinkedList,
/// Atomic copy of queue.len
queue_len: u32,
-free: std.DoublyLinkedList(void),
+free: std.DoublyLinkedList,
main_fiber: Fiber,
idle_count: usize,
threads: std.ArrayListUnmanaged(Thread),
@@ -39,7 +40,7 @@ const Thread = struct {
const Fiber = struct {
context: Context,
awaiter: ?*Fiber,
- queue_node: std.DoublyLinkedList(void).Node,
+ queue_node: std.DoublyLinkedList.Node,
result_align: Alignment,
const finished: ?*Fiber = @ptrFromInt(std.mem.alignBackward(usize, std.math.maxInt(usize), @alignOf(Fiber)));
@@ -447,6 +448,15 @@ pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []
event_loop.recycle(future_fiber);
}
+pub fn cancel(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void {
+ const event_loop: *EventLoop = @alignCast(@ptrCast(userdata));
+ const future_fiber: *Fiber = @alignCast(@ptrCast(any_future));
+ // TODO set a flag that makes all IO operations for this fiber return error.Canceled
+ if (@atomicLoad(?*Fiber, &future_fiber.awaiter, .acquire) != Fiber.finished) event_loop.yield(null, .{ .register_awaiter = &future_fiber.awaiter });
+ @memcpy(result, future_fiber.resultPointer());
+ event_loop.recycle(future_fiber);
+}
+
pub fn createFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.CreateFlags) std.fs.File.OpenError!std.fs.File {
const el: *EventLoop = @ptrCast(@alignCast(userdata));
lib/std/Thread/Pool.zig
@@ -1,22 +1,27 @@
const builtin = @import("builtin");
const std = @import("std");
+const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
const WaitGroup = @import("WaitGroup.zig");
+const Io = std.Io;
const Pool = @This();
+/// Must be a thread-safe allocator.
+allocator: std.mem.Allocator,
mutex: std.Thread.Mutex = .{},
cond: std.Thread.Condition = .{},
run_queue: std.SinglyLinkedList = .{},
is_running: bool = true,
-/// Must be a thread-safe allocator.
-allocator: std.mem.Allocator,
-threads: if (builtin.single_threaded) [0]std.Thread else []std.Thread,
+threads: std.ArrayListUnmanaged(std.Thread),
ids: if (builtin.single_threaded) struct {
inline fn deinit(_: @This(), _: std.mem.Allocator) void {}
fn getIndex(_: @This(), _: std.Thread.Id) usize {
return 0;
}
} else std.AutoArrayHashMapUnmanaged(std.Thread.Id, void),
+stack_size: usize,
+
+threadlocal var current_closure: ?*AsyncClosure = null;
pub const Runnable = struct {
runFn: RunProto,
@@ -33,48 +38,36 @@ pub const Options = struct {
};
pub fn init(pool: *Pool, options: Options) !void {
- const allocator = options.allocator;
+ const gpa = options.allocator;
+ const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
+ const threads = try gpa.alloc(std.Thread, thread_count);
+ errdefer gpa.free(threads);
pool.* = .{
- .allocator = allocator,
- .threads = if (builtin.single_threaded) .{} else &.{},
+ .allocator = gpa,
+ .threads = .initBuffer(threads),
.ids = .{},
+ .stack_size = options.stack_size,
};
- if (builtin.single_threaded) {
- return;
- }
+ if (builtin.single_threaded) return;
- const thread_count = options.n_jobs orelse @max(1, std.Thread.getCpuCount() catch 1);
if (options.track_ids) {
- try pool.ids.ensureTotalCapacity(allocator, 1 + thread_count);
+ try pool.ids.ensureTotalCapacity(gpa, 1 + thread_count);
pool.ids.putAssumeCapacityNoClobber(std.Thread.getCurrentId(), {});
}
-
- // kill and join any threads we spawned and free memory on error.
- pool.threads = try allocator.alloc(std.Thread, thread_count);
- var spawned: usize = 0;
- errdefer pool.join(spawned);
-
- for (pool.threads) |*thread| {
- thread.* = try std.Thread.spawn(.{
- .stack_size = options.stack_size,
- .allocator = allocator,
- }, worker, .{pool});
- spawned += 1;
- }
}
pub fn deinit(pool: *Pool) void {
- pool.join(pool.threads.len); // kill and join all threads.
- pool.ids.deinit(pool.allocator);
+ const gpa = pool.allocator;
+ pool.join();
+ pool.threads.deinit(gpa);
+ pool.ids.deinit(gpa);
pool.* = undefined;
}
-fn join(pool: *Pool, spawned: usize) void {
- if (builtin.single_threaded) {
- return;
- }
+fn join(pool: *Pool) void {
+ if (builtin.single_threaded) return;
{
pool.mutex.lock();
@@ -87,11 +80,7 @@ fn join(pool: *Pool, spawned: usize) void {
// wake up any sleeping threads (this can be done outside the mutex)
// then wait for all the threads we know are spawned to complete.
pool.cond.broadcast();
- for (pool.threads[0..spawned]) |thread| {
- thread.join();
- }
-
- pool.allocator.free(pool.threads);
+ for (pool.threads.items) |thread| thread.join();
}
/// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
@@ -123,26 +112,34 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args
}
};
- {
- pool.mutex.lock();
-
- const closure = pool.allocator.create(Closure) catch {
- pool.mutex.unlock();
- @call(.auto, func, args);
- wait_group.finish();
- return;
- };
- closure.* = .{
- .arguments = args,
- .pool = pool,
- .wait_group = wait_group,
- };
+ pool.mutex.lock();
- pool.run_queue.prepend(&closure.runnable.node);
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
pool.mutex.unlock();
+ @call(.auto, func, args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
}
- // Notify waiting threads outside the lock to try and keep the critical section small.
+ pool.mutex.unlock();
pool.cond.signal();
}
@@ -179,31 +176,39 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar
}
};
- {
- pool.mutex.lock();
-
- const closure = pool.allocator.create(Closure) catch {
- const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
- pool.mutex.unlock();
- @call(.auto, func, .{id.?} ++ args);
- wait_group.finish();
- return;
- };
- closure.* = .{
- .arguments = args,
- .pool = pool,
- .wait_group = wait_group,
- };
+ pool.mutex.lock();
- pool.run_queue.prepend(&closure.runnable.node);
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
+ const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId());
pool.mutex.unlock();
+ @call(.auto, func, .{id.?} ++ args);
+ wait_group.finish();
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ .wait_group = wait_group,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
}
- // Notify waiting threads outside the lock to try and keep the critical section small.
+ pool.mutex.unlock();
pool.cond.signal();
}
-pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
+pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) void {
if (builtin.single_threaded) {
@call(.auto, func, args);
return;
@@ -222,20 +227,32 @@ pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void {
}
};
- {
- pool.mutex.lock();
- defer pool.mutex.unlock();
+ pool.mutex.lock();
- const closure = try pool.allocator.create(Closure);
- closure.* = .{
- .arguments = args,
- .pool = pool,
- };
+ const gpa = pool.allocator;
+ const closure = gpa.create(Closure) catch {
+ pool.mutex.unlock();
+ @call(.auto, func, args);
+ return;
+ };
+ closure.* = .{
+ .arguments = args,
+ .pool = pool,
+ };
+
+ pool.run_queue.prepend(&closure.runnable.node);
- pool.run_queue.prepend(&closure.runnable.node);
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
}
- // Notify waiting threads outside the lock to try and keep the critical section small.
+ pool.mutex.unlock();
pool.cond.signal();
}
@@ -254,7 +271,7 @@ test spawn {
.allocator = std.testing.allocator,
});
defer pool.deinit();
- try pool.spawn(TestFn.checkRun, .{&completed});
+ pool.spawn(TestFn.checkRun, .{&completed});
}
try std.testing.expectEqual(true, completed);
@@ -306,15 +323,17 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void {
}
pub fn getIdCount(pool: *Pool) usize {
- return @intCast(1 + pool.threads.len);
+ return @intCast(1 + pool.threads.items.len);
}
-pub fn io(pool: *Pool) std.Io {
+pub fn io(pool: *Pool) Io {
return .{
.userdata = pool,
.vtable = &.{
.@"async" = @"async",
.@"await" = @"await",
+ .cancel = cancel,
+ .cancelRequested = cancelRequested,
.createFile = createFile,
.openFile = openFile,
.closeFile = closeFile,
@@ -326,15 +345,17 @@ pub fn io(pool: *Pool) std.Io {
const AsyncClosure = struct {
func: *const fn (context: *anyopaque, result: *anyopaque) void,
- run_node: std.Thread.Pool.RunQueue.Node = .{ .data = .{ .runFn = runFn } },
+ runnable: Runnable = .{ .runFn = runFn },
reset_event: std.Thread.ResetEvent,
+ cancel_flag: bool,
context_offset: usize,
result_offset: usize,
fn runFn(runnable: *std.Thread.Pool.Runnable, _: ?usize) void {
- const run_node: *std.Thread.Pool.RunQueue.Node = @fieldParentPtr("data", runnable);
- const closure: *AsyncClosure = @alignCast(@fieldParentPtr("run_node", run_node));
+ const closure: *AsyncClosure = @alignCast(@fieldParentPtr("runnable", runnable));
+ current_closure = closure;
closure.func(closure.contextPointer(), closure.resultPointer());
+ current_closure = null;
closure.reset_event.set();
}
@@ -359,16 +380,23 @@ const AsyncClosure = struct {
const base: [*]u8 = @ptrCast(closure);
return base + closure.context_offset;
}
+
+ fn waitAndFree(closure: *AsyncClosure, gpa: Allocator, result: []u8) void {
+ closure.reset_event.wait();
+ const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure);
+ @memcpy(result, closure.resultPointer()[0..result.len]);
+ gpa.free(base[0 .. closure.result_offset + result.len]);
+ }
};
-pub fn @"async"(
+fn @"async"(
userdata: ?*anyopaque,
result: []u8,
result_alignment: std.mem.Alignment,
context: []const u8,
context_alignment: std.mem.Alignment,
start: *const fn (context: *const anyopaque, result: *anyopaque) void,
-) ?*std.Io.AnyFuture {
+) ?*Io.AnyFuture {
const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
pool.mutex.lock();
@@ -386,46 +414,87 @@ pub fn @"async"(
.context_offset = context_offset,
.result_offset = result_offset,
.reset_event = .{},
+ .cancel_flag = false,
};
@memcpy(closure.contextPointer()[0..context.len], context);
- pool.run_queue.prepend(&closure.run_node);
- pool.mutex.unlock();
+ pool.run_queue.prepend(&closure.runnable.node);
+
+ if (pool.threads.items.len < pool.threads.capacity) {
+ pool.threads.addOneAssumeCapacity().* = std.Thread.spawn(.{
+ .stack_size = pool.stack_size,
+ .allocator = gpa,
+ }, worker, .{pool}) catch t: {
+ pool.threads.items.len -= 1;
+ break :t undefined;
+ };
+ }
+ pool.mutex.unlock();
pool.cond.signal();
return @ptrCast(closure);
}
-pub fn @"await"(userdata: ?*anyopaque, any_future: *std.Io.AnyFuture, result: []u8) void {
- const thread_pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+fn @"await"(userdata: ?*anyopaque, any_future: *Io.AnyFuture, result: []u8) void {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
- closure.reset_event.wait();
- const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(closure);
- @memcpy(result, closure.resultPointer()[0..result.len]);
- thread_pool.allocator.free(base[0 .. closure.result_offset + result.len]);
+ closure.waitAndFree(pool.allocator, result);
+}
+
+fn cancel(userdata: ?*anyopaque, any_future: *Io.AnyFuture, result: []u8) void {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
+ @atomicStore(bool, &closure.cancel_flag, true, .seq_cst);
+ closure.waitAndFree(pool.allocator, result);
+}
+
+fn cancelRequested(userdata: ?*anyopaque) bool {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
+ const closure = current_closure orelse return false;
+ return @atomicLoad(bool, &closure.cancel_flag, .unordered);
+}
+
+fn checkCancel(pool: *Pool) error{AsyncCancel}!void {
+ if (cancelRequested(pool)) return error.AsyncCancel;
}
-pub fn createFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.CreateFlags) std.fs.File.OpenError!std.fs.File {
- _ = userdata;
+pub fn createFile(
+ userdata: ?*anyopaque,
+ dir: std.fs.Dir,
+ sub_path: []const u8,
+ flags: std.fs.File.CreateFlags,
+) Io.FileOpenError!std.fs.File {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
return dir.createFile(sub_path, flags);
}
-pub fn openFile(userdata: ?*anyopaque, dir: std.fs.Dir, sub_path: []const u8, flags: std.fs.File.OpenFlags) std.fs.File.OpenError!std.fs.File {
- _ = userdata;
+pub fn openFile(
+ userdata: ?*anyopaque,
+ dir: std.fs.Dir,
+ sub_path: []const u8,
+ flags: std.fs.File.OpenFlags,
+) Io.FileOpenError!std.fs.File {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
return dir.openFile(sub_path, flags);
}
pub fn closeFile(userdata: ?*anyopaque, file: std.fs.File) void {
- _ = userdata;
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ _ = pool;
return file.close();
}
-pub fn read(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8) std.fs.File.ReadError!usize {
- _ = userdata;
+pub fn read(userdata: ?*anyopaque, file: std.fs.File, buffer: []u8) Io.FileReadError!usize {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
return file.read(buffer);
}
-pub fn write(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8) std.fs.File.WriteError!usize {
- _ = userdata;
+pub fn write(userdata: ?*anyopaque, file: std.fs.File, buffer: []const u8) Io.FileWriteError!usize {
+ const pool: *std.Thread.Pool = @alignCast(@ptrCast(userdata));
+ try pool.checkCancel();
return file.write(buffer);
}
lib/std/Io.zig
@@ -564,6 +564,8 @@ vtable: *const VTable,
pub const VTable = struct {
/// If it returns `null` it means `result` has been already populated and
/// `await` will be a no-op.
+ ///
+ /// Thread-safe.
async: *const fn (
/// Corresponds to `Io.userdata`.
userdata: ?*anyopaque,
@@ -579,6 +581,8 @@ pub const VTable = struct {
) ?*AnyFuture,
/// This function is only called when `async` returns a non-null value.
+ ///
+ /// Thread-safe.
await: *const fn (
/// Corresponds to `Io.userdata`.
userdata: ?*anyopaque,
@@ -589,13 +593,41 @@ pub const VTable = struct {
result: []u8,
) void,
- createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) fs.File.OpenError!fs.File,
- openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) fs.File.OpenError!fs.File,
+ /// Equivalent to `await` but initiates cancel request.
+ ///
+ /// This function is only called when `async` returns a non-null value.
+ ///
+ /// Thread-safe.
+ cancel: *const fn (
+ /// Corresponds to `Io.userdata`.
+ userdata: ?*anyopaque,
+ /// The same value that was returned from `async`.
+ any_future: *AnyFuture,
+ /// Points to a buffer where the result is written.
+ /// The length is equal to size in bytes of result type.
+ result: []u8,
+ ) void,
+
+ /// Returns whether the current thread of execution is known to have
+ /// been requested to cancel.
+ ///
+ /// Thread-safe.
+ cancelRequested: *const fn (?*anyopaque) bool,
+
+ createFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File,
+ openFile: *const fn (?*anyopaque, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File,
closeFile: *const fn (?*anyopaque, fs.File) void,
- read: *const fn (?*anyopaque, file: fs.File, buffer: []u8) fs.File.ReadError!usize,
- write: *const fn (?*anyopaque, file: fs.File, buffer: []const u8) fs.File.WriteError!usize,
+ read: *const fn (?*anyopaque, file: fs.File, buffer: []u8) FileReadError!usize,
+ write: *const fn (?*anyopaque, file: fs.File, buffer: []const u8) FileWriteError!usize,
};
+pub const OpenFlags = fs.File.OpenFlags;
+pub const CreateFlags = fs.File.CreateFlags;
+
+pub const FileOpenError = fs.File.OpenError || error{AsyncCancel};
+pub const FileReadError = fs.File.ReadError || error{AsyncCancel};
+pub const FileWriteError = fs.File.WriteError || error{AsyncCancel};
+
pub const AnyFuture = opaque {};
pub fn Future(Result: type) type {
@@ -603,6 +635,17 @@ pub fn Future(Result: type) type {
any_future: ?*AnyFuture,
result: Result,
+ /// Equivalent to `await` but sets a flag observable to application
+ /// code that cancellation has been requested.
+ ///
+ /// Idempotent.
+ pub fn cancel(f: *@This(), io: Io) Result {
+ const any_future = f.any_future orelse return f.result;
+ io.vtable.cancel(io.userdata, any_future, @ptrCast((&f.result)[0..1]));
+ f.any_future = null;
+ return f.result;
+ }
+
pub fn await(f: *@This(), io: Io) Result {
const any_future = f.any_future orelse return f.result;
io.vtable.await(io.userdata, any_future, @ptrCast((&f.result)[0..1]));
@@ -636,11 +679,11 @@ pub fn async(io: Io, function: anytype, args: anytype) Future(@typeInfo(@TypeOf(
return future;
}
-pub fn openFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) fs.File.OpenError!fs.File {
+pub fn openFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.OpenFlags) FileOpenError!fs.File {
return io.vtable.openFile(io.userdata, dir, sub_path, flags);
}
-pub fn createFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) fs.File.OpenError!fs.File {
+pub fn createFile(io: Io, dir: fs.Dir, sub_path: []const u8, flags: fs.File.CreateFlags) FileOpenError!fs.File {
return io.vtable.createFile(io.userdata, dir, sub_path, flags);
}
@@ -648,22 +691,22 @@ pub fn closeFile(io: Io, file: fs.File) void {
return io.vtable.closeFile(io.userdata, file);
}
-pub fn read(io: Io, file: fs.File, buffer: []u8) fs.File.ReadError!usize {
+pub fn read(io: Io, file: fs.File, buffer: []u8) FileReadError!usize {
return io.vtable.read(io.userdata, file, buffer);
}
-pub fn write(io: Io, file: fs.File, buffer: []const u8) fs.File.WriteError!usize {
+pub fn write(io: Io, file: fs.File, buffer: []const u8) FileWriteError!usize {
return io.vtable.write(io.userdata, file, buffer);
}
-pub fn writeAll(io: Io, file: fs.File, bytes: []const u8) fs.File.WriteError!void {
+pub fn writeAll(io: Io, file: fs.File, bytes: []const u8) FileWriteError!void {
var index: usize = 0;
while (index < bytes.len) {
index += try io.write(file, bytes[index..]);
}
}
-pub fn readAll(io: Io, file: fs.File, buffer: []u8) fs.File.ReadError!usize {
+pub fn readAll(io: Io, file: fs.File, buffer: []u8) FileReadError!usize {
var index: usize = 0;
while (index != buffer.len) {
const amt = try io.read(file, buffer[index..]);