Commit dbcd53def0

John Schmidt <john.schmidt.h@gmail.com>
2024-02-04 23:04:43
Preserve field alignment in union pointer captures
1 parent 919a3ba
Changed files (3)
src
test
src/Sema.zig
@@ -10992,51 +10992,67 @@ const SwitchProngAnalysis = struct {
 
                 // By-reference captures have some further restrictions which make them easier to emit
                 if (capture_byref) {
+                    const first_field_alignment = union_obj.fieldAlign(ip, first_field_index);
+                    const same_alignment = for (field_indices[1..]) |field_idx| {
+                        const field_alignment = union_obj.fieldAlign(ip, field_idx);
+                        if (field_alignment != first_field_alignment) break false;
+                    } else true;
                     const operand_ptr_info = operand_ptr_ty.ptrInfo(mod);
-                    const capture_ptr_ty = try sema.ptrType(.{
-                        .child = capture_ty.toIntern(),
-                        .flags = .{
-                            // TODO: alignment!
-                            .is_const = operand_ptr_info.flags.is_const,
-                            .is_volatile = operand_ptr_info.flags.is_volatile,
-                            .address_space = operand_ptr_info.flags.address_space,
-                        },
-                    });
-
-                    // By-ref captures of hetereogeneous types are only allowed if each field
-                    // pointer type is in-memory coercible to the capture pointer type.
-                    if (!same_types) {
-                        for (field_indices, 0..) |field_idx, i| {
+                    const capture_ptr_ty = if (same_types and same_alignment) same: {
+                        break :same try sema.ptrType(.{
+                            .child = capture_ty.toIntern(),
+                            .flags = .{
+                                .is_const = operand_ptr_info.flags.is_const,
+                                .is_volatile = operand_ptr_info.flags.is_volatile,
+                                .address_space = operand_ptr_info.flags.address_space,
+                                .alignment = first_field_alignment,
+                            },
+                        });
+                    } else resolve: {
+                        // By-ref captures of hetereogeneous types are only allowed if all field
+                        // pointer types are peer resolvable to each other.
+                        // We need values to run PTR on, so make a bunch of undef constants.
+                        const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
+                        for (field_indices, dummy_captures) |field_idx, *dummy| {
                             const field_ty = Type.fromInterned(union_obj.field_types.get(ip)[field_idx]);
                             const field_ptr_ty = try sema.ptrType(.{
                                 .child = field_ty.toIntern(),
                                 .flags = .{
-                                    // TODO: alignment!
                                     .is_const = operand_ptr_info.flags.is_const,
                                     .is_volatile = operand_ptr_info.flags.is_volatile,
                                     .address_space = operand_ptr_info.flags.address_space,
+                                    .alignment = union_obj.fieldAlign(ip, field_idx),
                                 },
                             });
-                            if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
+                            dummy.* = try mod.undefRef(field_ptr_ty);
+                        }
+                        const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len);
+                        @memset(case_srcs, .unneeded);
+
+                        break :resolve sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) {
+                            error.NeededSourceLocation => {
+                                // This must be a multi-prong so this must be a `multi_capture` src
                                 const multi_idx = raw_capture_src.multi_capture;
                                 const src_decl_ptr = sema.mod.declPtr(block.src_decl);
+                                for (case_srcs, 0..) |*case_src, i| {
+                                    const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(i) } };
+                                    case_src.* = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+                                }
                                 const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
-                                const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(i) } };
-                                const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
-                                const msg = msg: {
-                                    const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
-                                    errdefer msg.destroy(sema.gpa);
-                                    try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{
-                                        field_ty.fmt(sema.mod),
-                                        capture_ty.fmt(sema.mod),
-                                    });
-                                    try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{});
-                                    break :msg msg;
+                                _ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) {
+                                    error.AnalysisFail => {
+                                        const msg = sema.err orelse return error.AnalysisFail;
+                                        try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{});
+                                        try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{});
+                                        return error.AnalysisFail;
+                                    },
+                                    else => |e| return e,
                                 };
-                                return sema.failWithOwnedErrorMsg(block, msg);
-                            }
-                        }
-                    }
+                                unreachable;
+                            },
+                            else => |e| return e,
+                        };
+                    };
 
                     if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| {
                         if (op_ptr_val.isUndef(mod)) return mod.undefRef(capture_ptr_ty);
test/behavior/switch.zig
@@ -574,6 +574,69 @@ test "switch prongs with cases with identical payload types" {
     try comptime S.doTheTest();
 }
 
+test "switch prong pointer capture alignment" {
+    const U = union(enum) {
+        a: u8 align(8),
+        b: u8 align(4),
+        c: u8,
+    };
+
+    const S = struct {
+        fn doTheTest() !void {
+            const u = U{ .a = 1 };
+            switch (u) {
+                .a => |*a| try expectEqual(*align(8) const u8, @TypeOf(a)),
+                .b, .c => |*p| {
+                    _ = p;
+                    @panic("fail");
+                },
+            }
+
+            switch (u) {
+                .a, .b => |*p| try expectEqual(*align(4) const u8, @TypeOf(p)),
+                .c => |*p| {
+                    _ = p;
+                    @panic("fail");
+                },
+            }
+
+            switch (u) {
+                .a, .c => |*p| try expectEqual(*const u8, @TypeOf(p)),
+                .b => |*p| {
+                    _ = p;
+                    @panic("fail");
+                },
+            }
+        }
+
+        fn doTheTest2() !void {
+            const un1 = U{ .b = 1 };
+            switch (un1) {
+                .b => |*a| try expectEqual(*align(4) const u8, @TypeOf(a)),
+                .a, .c => |*p| {
+                    _ = p;
+                    @panic("fail");
+                },
+            }
+
+            const un2 = U{ .c = 1 };
+            switch (un2) {
+                .c => |*a| try expectEqual(*const u8, @TypeOf(a)),
+                .a, .b => |*p| {
+                    _ = p;
+                    @panic("fail");
+                },
+            }
+        }
+    };
+
+    try S.doTheTest();
+    try comptime S.doTheTest();
+
+    try S.doTheTest2();
+    try comptime S.doTheTest2();
+}
+
 test "switch on pointer type" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
test/cases/compile_errors/switch_capture_incompatible_types.zig
@@ -23,5 +23,7 @@ export fn g() void {
 // :5:10: note: type 'u32' here
 // :5:14: note: type '*u8' here
 // :13:20: error: capture group with incompatible types
-// :13:14: note: pointer type child 'u32' cannot cast into resolved pointer type child 'u64'
+// :13:20: note: incompatible types: '*u64' and '*u32'
+// :13:10: note: type '*u64' here
+// :13:14: note: type '*u32' here
 // :13:20: note: this coercion is only possible when capturing by value