Commit cca2d09950

Carl Åstholm <carl@astholm.se>
2025-11-05 01:00:44
io: Correctly align async closure contexts
This fixes package fetching on Windows. Previously, `Async/GroupClosure` allocations were only aligned for the closure struct type, which resulted in panics when `context_alignment` (or `result_alignment` for that matter) had a greater alignment.
1 parent 5f13922
Changed files (2)
lib
std
lib/std/Io/Threaded/test.zig
@@ -56,3 +56,62 @@ test "concurrent vs concurrent prevents deadlock via oversubscription" {
     getter.await(io);
     putter.await(io);
 }
+
+const ByteArray256 = struct { x: [32]u8 align(32) };
+const ByteArray512 = struct { x: [64]u8 align(64) };
+
+fn concatByteArrays(a: ByteArray256, b: ByteArray256) ByteArray512 {
+    return .{ .x = a.x ++ b.x };
+}
+
+test "async/concurrent context and result alignment" {
+    var buffer: [2048]u8 align(@alignOf(ByteArray512)) = undefined;
+    var fba: std.heap.FixedBufferAllocator = .init(&buffer);
+
+    var threaded: std.Io.Threaded = .init(fba.allocator());
+    defer threaded.deinit();
+    const io = threaded.io();
+
+    const a: ByteArray256 = .{ .x = @splat(2) };
+    const b: ByteArray256 = .{ .x = @splat(3) };
+    const expected: ByteArray512 = .{ .x = @as([32]u8, @splat(2)) ++ @as([32]u8, @splat(3)) };
+
+    {
+        var future = io.async(concatByteArrays, .{ a, b });
+        const result = future.await(io);
+        try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+    }
+    {
+        var future = io.concurrent(concatByteArrays, .{ a, b }) catch |err| switch (err) {
+            error.ConcurrencyUnavailable => {
+                try testing.expect(builtin.single_threaded);
+                return;
+            },
+        };
+        const result = future.await(io);
+        try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+    }
+}
+
+fn concatByteArraysResultPtr(a: ByteArray256, b: ByteArray256, result: *ByteArray512) void {
+    result.* = .{ .x = a.x ++ b.x };
+}
+
+test "Group.async context alignment" {
+    var buffer: [2048]u8 align(@alignOf(ByteArray512)) = undefined;
+    var fba: std.heap.FixedBufferAllocator = .init(&buffer);
+
+    var threaded: std.Io.Threaded = .init(fba.allocator());
+    defer threaded.deinit();
+    const io = threaded.io();
+
+    const a: ByteArray256 = .{ .x = @splat(2) };
+    const b: ByteArray256 = .{ .x = @splat(3) };
+    const expected: ByteArray512 = .{ .x = @as([32]u8, @splat(2)) ++ @as([32]u8, @splat(3)) };
+
+    var group: std.Io.Group = .init;
+    var result: ByteArray512 = undefined;
+    group.async(io, concatByteArraysResultPtr, .{ a, b, &result });
+    group.wait(io);
+    try std.testing.expectEqualSlices(u8, &expected.x, &result.x);
+}
lib/std/Io/Threaded.zig
@@ -389,6 +389,7 @@ const AsyncClosure = struct {
     select_condition: ?*ResetEvent,
     context_alignment: std.mem.Alignment,
     result_offset: usize,
+    alloc_len: usize,
 
     const done_reset_event: *ResetEvent = @ptrFromInt(@alignOf(ResetEvent));
 
@@ -425,18 +426,59 @@ const AsyncClosure = struct {
 
     fn contextPointer(ac: *AsyncClosure) [*]u8 {
         const base: [*]u8 = @ptrCast(ac);
-        return base + ac.context_alignment.forward(@sizeOf(AsyncClosure));
+        const context_offset = ac.context_alignment.forward(@intFromPtr(ac) + @sizeOf(AsyncClosure)) - @intFromPtr(ac);
+        return base + context_offset;
+    }
+
+    fn init(
+        gpa: Allocator,
+        mode: enum { async, concurrent },
+        result_len: usize,
+        result_alignment: std.mem.Alignment,
+        context: []const u8,
+        context_alignment: std.mem.Alignment,
+        func: *const fn (context: *const anyopaque, result: *anyopaque) void,
+    ) Allocator.Error!*AsyncClosure {
+        const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(AsyncClosure);
+        const worst_case_context_offset = context_alignment.forward(@sizeOf(AsyncClosure) + max_context_misalignment);
+        const worst_case_result_offset = result_alignment.forward(worst_case_context_offset + context.len);
+        const alloc_len = worst_case_result_offset + result_len;
+
+        const ac: *AsyncClosure = @ptrCast(@alignCast(try gpa.alignedAlloc(u8, .of(AsyncClosure), alloc_len)));
+        errdefer comptime unreachable;
+
+        const actual_context_addr = context_alignment.forward(@intFromPtr(ac) + @sizeOf(AsyncClosure));
+        const actual_result_addr = result_alignment.forward(actual_context_addr + context.len);
+        const actual_result_offset = actual_result_addr - @intFromPtr(ac);
+        ac.* = .{
+            .closure = .{
+                .cancel_tid = .none,
+                .start = start,
+                .is_concurrent = switch (mode) {
+                    .async => false,
+                    .concurrent => true,
+                },
+            },
+            .func = func,
+            .context_alignment = context_alignment,
+            .result_offset = actual_result_offset,
+            .alloc_len = alloc_len,
+            .reset_event = .unset,
+            .select_condition = null,
+        };
+        @memcpy(ac.contextPointer()[0..context.len], context);
+        return ac;
     }
 
-    fn waitAndFree(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
+    fn waitAndDeinit(ac: *AsyncClosure, gpa: Allocator, result: []u8) void {
         ac.reset_event.waitUncancelable();
         @memcpy(result, ac.resultPointer()[0..result.len]);
-        free(ac, gpa, result.len);
+        ac.deinit(gpa);
     }
 
-    fn free(ac: *AsyncClosure, gpa: Allocator, result_len: usize) void {
+    fn deinit(ac: *AsyncClosure, gpa: Allocator) void {
         const base: [*]align(@alignOf(AsyncClosure)) u8 = @ptrCast(ac);
-        gpa.free(base[0 .. ac.result_offset + result_len]);
+        gpa.free(base[0..ac.alloc_len]);
     }
 };
 
@@ -452,6 +494,7 @@ fn async(
         start(context.ptr, result.ptr);
         return null;
     }
+
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch {
         return concurrent(userdata, result.len, result_alignment, context, context_alignment, start) catch {
@@ -459,37 +502,20 @@ fn async(
             return null;
         };
     };
+
     const gpa = t.allocator;
-    const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
-    const result_offset = result_alignment.forward(context_offset + context.len);
-    const n = result_offset + result.len;
-    const ac: *AsyncClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(AsyncClosure), n) catch {
+    const ac = AsyncClosure.init(gpa, .async, result.len, result_alignment, context, context_alignment, start) catch {
         start(context.ptr, result.ptr);
         return null;
-    }));
-
-    ac.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = AsyncClosure.start,
-            .is_concurrent = false,
-        },
-        .func = start,
-        .context_alignment = context_alignment,
-        .result_offset = result_offset,
-        .reset_event = .unset,
-        .select_condition = null,
     };
 
-    @memcpy(ac.contextPointer()[0..context.len], context);
-
     t.mutex.lock();
 
     const thread_capacity = cpu_count - 1 + t.concurrent_count;
 
     t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        ac.free(gpa, result.len);
+        ac.deinit(gpa);
         start(context.ptr, result.ptr);
         return null;
     };
@@ -501,7 +527,7 @@ fn async(
             if (t.threads.items.len == 0) {
                 assert(t.run_queue.popFirst() == &ac.closure.node);
                 t.mutex.unlock();
-                ac.free(gpa, result.len);
+                ac.deinit(gpa);
                 start(context.ptr, result.ptr);
                 return null;
             }
@@ -530,27 +556,11 @@ fn concurrent(
 
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch 1;
+
     const gpa = t.allocator;
-    const context_offset = context_alignment.forward(@sizeOf(AsyncClosure));
-    const result_offset = result_alignment.forward(context_offset + context.len);
-    const n = result_offset + result_len;
-    const ac_bytes = gpa.alignedAlloc(u8, .of(AsyncClosure), n) catch
+    const ac = AsyncClosure.init(gpa, .concurrent, result_len, result_alignment, context, context_alignment, start) catch {
         return error.ConcurrencyUnavailable;
-    const ac: *AsyncClosure = @ptrCast(@alignCast(ac_bytes));
-
-    ac.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = AsyncClosure.start,
-            .is_concurrent = true,
-        },
-        .func = start,
-        .context_alignment = context_alignment,
-        .result_offset = result_offset,
-        .reset_event = .unset,
-        .select_condition = null,
     };
-    @memcpy(ac.contextPointer()[0..context.len], context);
 
     t.mutex.lock();
 
@@ -559,7 +569,7 @@ fn concurrent(
 
     t.threads.ensureTotalCapacity(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        ac.free(gpa, result_len);
+        ac.deinit(gpa);
         return error.ConcurrencyUnavailable;
     };
 
@@ -569,7 +579,7 @@ fn concurrent(
         const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
             assert(t.run_queue.popFirst() == &ac.closure.node);
             t.mutex.unlock();
-            ac.free(gpa, result_len);
+            ac.deinit(gpa);
             return error.ConcurrencyUnavailable;
         };
         t.threads.appendAssumeCapacity(thread);
@@ -588,7 +598,7 @@ const GroupClosure = struct {
     node: std.SinglyLinkedList.Node,
     func: *const fn (*Io.Group, context: *anyopaque) void,
     context_alignment: std.mem.Alignment,
-    context_len: usize,
+    alloc_len: usize,
 
     fn start(closure: *Closure) void {
         const gc: *GroupClosure = @alignCast(@fieldParentPtr("closure", closure));
@@ -616,22 +626,48 @@ const GroupClosure = struct {
         if (prev_state == (sync_one_pending | sync_is_waiting)) reset_event.set();
     }
 
-    fn free(gc: *GroupClosure, gpa: Allocator) void {
-        const base: [*]align(@alignOf(GroupClosure)) u8 = @ptrCast(gc);
-        gpa.free(base[0..contextEnd(gc.context_alignment, gc.context_len)]);
-    }
-
-    fn contextOffset(context_alignment: std.mem.Alignment) usize {
-        return context_alignment.forward(@sizeOf(GroupClosure));
-    }
-
-    fn contextEnd(context_alignment: std.mem.Alignment, context_len: usize) usize {
-        return contextOffset(context_alignment) + context_len;
-    }
-
     fn contextPointer(gc: *GroupClosure) [*]u8 {
         const base: [*]u8 = @ptrCast(gc);
-        return base + contextOffset(gc.context_alignment);
+        const context_offset = gc.context_alignment.forward(@intFromPtr(gc) + @sizeOf(GroupClosure)) - @intFromPtr(gc);
+        return base + context_offset;
+    }
+
+    /// Does not initialize the `node` field.
+    fn init(
+        gpa: Allocator,
+        t: *Threaded,
+        group: *Io.Group,
+        context: []const u8,
+        context_alignment: std.mem.Alignment,
+        func: *const fn (*Io.Group, context: *const anyopaque) void,
+    ) Allocator.Error!*GroupClosure {
+        const max_context_misalignment = context_alignment.toByteUnits() -| @alignOf(GroupClosure);
+        const worst_case_context_offset = context_alignment.forward(@sizeOf(GroupClosure) + max_context_misalignment);
+        const alloc_len = worst_case_context_offset + context.len;
+
+        const gc: *GroupClosure = @ptrCast(@alignCast(try gpa.alignedAlloc(u8, .of(GroupClosure), alloc_len)));
+        errdefer comptime unreachable;
+
+        gc.* = .{
+            .closure = .{
+                .cancel_tid = .none,
+                .start = start,
+                .is_concurrent = false,
+            },
+            .t = t,
+            .group = group,
+            .node = undefined,
+            .func = func,
+            .context_alignment = context_alignment,
+            .alloc_len = alloc_len,
+        };
+        @memcpy(gc.contextPointer()[0..context.len], context);
+        return gc;
+    }
+
+    fn deinit(gc: *GroupClosure, gpa: Allocator) void {
+        const base: [*]align(@alignOf(GroupClosure)) u8 = @ptrCast(gc);
+        gpa.free(base[0..gc.alloc_len]);
     }
 
     const sync_is_waiting: usize = 1 << 0;
@@ -646,27 +682,14 @@ fn groupAsync(
     start: *const fn (*Io.Group, context: *const anyopaque) void,
 ) void {
     if (builtin.single_threaded) return start(group, context.ptr);
+
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const cpu_count = t.cpu_count catch 1;
+
     const gpa = t.allocator;
-    const n = GroupClosure.contextEnd(context_alignment, context.len);
-    const gc: *GroupClosure = @ptrCast(@alignCast(gpa.alignedAlloc(u8, .of(GroupClosure), n) catch {
+    const gc = GroupClosure.init(gpa, t, group, context, context_alignment, start) catch {
         return start(group, context.ptr);
-    }));
-    gc.* = .{
-        .closure = .{
-            .cancel_tid = .none,
-            .start = GroupClosure.start,
-            .is_concurrent = false,
-        },
-        .t = t,
-        .group = group,
-        .node = undefined,
-        .func = start,
-        .context_alignment = context_alignment,
-        .context_len = context.len,
     };
-    @memcpy(gc.contextPointer()[0..context.len], context);
 
     t.mutex.lock();
 
@@ -678,7 +701,7 @@ fn groupAsync(
 
     t.threads.ensureTotalCapacityPrecise(gpa, thread_capacity) catch {
         t.mutex.unlock();
-        gc.free(gpa);
+        gc.deinit(gpa);
         return start(group, context.ptr);
     };
 
@@ -688,7 +711,7 @@ fn groupAsync(
         const thread = std.Thread.spawn(.{ .stack_size = t.stack_size }, worker, .{t}) catch {
             assert(t.run_queue.popFirst() == &gc.closure.node);
             t.mutex.unlock();
-            gc.free(gpa);
+            gc.deinit(gpa);
             return start(group, context.ptr);
         };
         t.threads.appendAssumeCapacity(thread);
@@ -730,7 +753,7 @@ fn groupWait(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void {
     while (true) {
         const gc: *GroupClosure = @fieldParentPtr("node", node);
         const node_next = node.next;
-        gc.free(gpa);
+        gc.deinit(gpa);
         node = node_next orelse break;
     }
 }
@@ -761,7 +784,7 @@ fn groupCancel(userdata: ?*anyopaque, group: *Io.Group, token: *anyopaque) void
         while (true) {
             const gc: *GroupClosure = @fieldParentPtr("node", node);
             const node_next = node.next;
-            gc.free(gpa);
+            gc.deinit(gpa);
             node = node_next orelse break;
         }
     }
@@ -776,7 +799,7 @@ fn await(
     _ = result_alignment;
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const closure: *AsyncClosure = @ptrCast(@alignCast(any_future));
-    closure.waitAndFree(t.allocator, result);
+    closure.waitAndDeinit(t.allocator, result);
 }
 
 fn cancel(
@@ -789,7 +812,7 @@ fn cancel(
     const t: *Threaded = @ptrCast(@alignCast(userdata));
     const ac: *AsyncClosure = @ptrCast(@alignCast(any_future));
     ac.closure.requestCancel();
-    ac.waitAndFree(t.allocator, result);
+    ac.waitAndDeinit(t.allocator, result);
 }
 
 fn cancelRequested(userdata: ?*anyopaque) bool {