Commit 2611d97fb0

mlugg <mlugg@mlugg.co.uk>
2023-06-10 02:23:17
Sema: copy pointer alignment to union field pointers
This implements the semantics as discussed in today's compiler meeting, where the alignment of pointers to fields of default-layout unions cannot exceed the field's alignment. Resolves: #15878
1 parent 3f04231
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -26060,6 +26060,7 @@ fn unionFieldPtr(
     assert(unresolved_union_ty.zigTypeTag(mod) == .Union);
 
     const union_ptr_ty = sema.typeOf(union_ptr);
+    const union_ptr_info = union_ptr_ty.ptrInfo(mod);
     const union_ty = try sema.resolveTypeFields(unresolved_union_ty);
     const union_obj = mod.typeToUnion(union_ty).?;
     const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_name_src);
@@ -26067,10 +26068,16 @@ fn unionFieldPtr(
     const ptr_field_ty = try mod.ptrType(.{
         .child = field.ty.toIntern(),
         .flags = .{
-            .is_const = !union_ptr_ty.ptrIsMutable(mod),
-            .is_volatile = union_ptr_ty.isVolatilePtr(mod),
-            .address_space = union_ptr_ty.ptrAddressSpace(mod),
-        },
+            .is_const = union_ptr_info.flags.is_const,
+            .is_volatile = union_ptr_info.flags.is_volatile,
+            .address_space = union_ptr_info.flags.address_space,
+            .alignment = if (union_obj.layout == .Auto) blk: {
+                const union_align = union_ptr_info.flags.alignment.toByteUnitsOptional() orelse try sema.typeAbiAlignment(union_ty);
+                const field_align = try sema.unionFieldAlignment(field);
+                break :blk InternPool.Alignment.fromByteUnits(@min(union_align, field_align));
+            } else union_ptr_info.flags.alignment,
+        },
+        .packed_offset = union_ptr_info.packed_offset,
     });
     const enum_field_index = @as(u32, @intCast(union_obj.tag_ty.enumFieldIndex(field_name, mod).?));
 
test/behavior/union.zig
@@ -1583,3 +1583,112 @@ test "coerce enum literal to union in result loc" {
     try U.doTest(true);
     try comptime U.doTest(true);
 }
+
+test "defined-layout union field pointer has correct alignment" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest(comptime U: type) !void {
+            var a: U = .{ .x = 123 };
+            var b: U align(1) = .{ .x = 456 };
+            var c: U align(64) = .{ .x = 789 };
+
+            const ap = &a.x;
+            const bp = &b.x;
+            const cp = &c.x;
+
+            comptime assert(@TypeOf(ap) == *u32);
+            comptime assert(@TypeOf(bp) == *align(1) u32);
+            comptime assert(@TypeOf(cp) == *align(64) u32);
+
+            try expectEqual(@as(u32, 123), ap.*);
+            try expectEqual(@as(u32, 456), bp.*);
+            try expectEqual(@as(u32, 789), cp.*);
+        }
+    };
+
+    const U1 = extern union { x: u32 };
+    const U2 = packed union { x: u32 };
+
+    try S.doTheTest(U1);
+    try S.doTheTest(U2);
+    try comptime S.doTheTest(U1);
+    try comptime S.doTheTest(U2);
+}
+
+test "undefined-layout union field pointer has correct alignment" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest(comptime U: type) !void {
+            var a: U = .{ .x = 123 };
+            var b: U align(1) = .{ .x = 456 };
+            var c: U align(64) = .{ .x = 789 };
+
+            const ap = &a.x;
+            const bp = &b.x;
+            const cp = &c.x;
+
+            comptime assert(@TypeOf(ap) == *u32);
+            comptime assert(@TypeOf(bp) == *align(1) u32);
+            comptime assert(@TypeOf(cp) == *u32); // undefined layout so does not inherit larger aligns
+
+            try expectEqual(@as(u32, 123), ap.*);
+            try expectEqual(@as(u32, 456), bp.*);
+            try expectEqual(@as(u32, 789), cp.*);
+        }
+    };
+
+    const U1 = union { x: u32 };
+    const U2 = union(enum) { x: u32 };
+
+    try S.doTheTest(U1);
+    try S.doTheTest(U2);
+    try comptime S.doTheTest(U1);
+    try comptime S.doTheTest(U2);
+}
+
+test "packed union field pointer has correct alignment" {
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
+
+    const U = packed union { x: u20 };
+    const S = packed struct(u24) { a: u2, u: U, b: u2 };
+
+    var a: S = undefined;
+    var b: S align(1) = undefined;
+    var c: S align(64) = undefined;
+
+    const ap = &a.u.x;
+    const bp = &b.u.x;
+    const cp = &c.u.x;
+
+    comptime assert(@TypeOf(ap) == *align(4:2:3) u20);
+    comptime assert(@TypeOf(bp) == *align(1:2:3) u20);
+    comptime assert(@TypeOf(cp) == *align(64:2:3) u20);
+
+    a.u = .{ .x = 123 };
+    b.u = .{ .x = 456 };
+    c.u = .{ .x = 789 };
+
+    try expectEqual(@as(u20, 123), ap.*);
+    try expectEqual(@as(u20, 456), bp.*);
+    try expectEqual(@as(u20, 789), cp.*);
+}