Commit 4d044ee7e0

kcbanner <kcbanner@gmail.com>
2023-10-21 18:03:36
sema: Add union alignment resolution
- Add resolveUnionAlignment, to resolve a union's alignment only, without triggering layout resolution. - Update resolveUnionLayout to cache size, alignment, and padding. abiSizeAdvanced and abiAlignmentAdvanced now use this information instead of computing it each time.
1 parent ffaeb45
src/InternPool.zig
@@ -491,7 +491,7 @@ pub const Key = union(enum) {
 
         /// The returned pointer expires with any addition to the `InternPool`.
         /// Asserts the struct is not packed.
-        pub fn flagsPtr(self: @This(), ip: *InternPool) *Tag.TypeStruct.Flags {
+        pub fn flagsPtr(self: @This(), ip: *const InternPool) *Tag.TypeStruct.Flags {
             assert(self.layout != .Packed);
             const flags_field_index = std.meta.fieldIndex(Tag.TypeStruct, "flags").?;
             return @ptrCast(&ip.extra.items[self.extra_index + flags_field_index]);
@@ -687,6 +687,18 @@ pub const Key = union(enum) {
             return @ptrCast(&ip.extra.items[self.extra_index + flags_field_index]);
         }
 
+        /// The returned pointer expires with any addition to the `InternPool`.
+        pub fn size(self: @This(), ip: *InternPool) *u32 {
+            const size_field_index = std.meta.fieldIndex(Tag.TypeUnion, "size").?;
+            return @ptrCast(&ip.extra.items[self.extra_index + size_field_index]);
+        }
+
+        /// The returned pointer expires with any addition to the `InternPool`.
+        pub fn padding(self: @This(), ip: *InternPool) *u32 {
+            const padding_field_index = std.meta.fieldIndex(Tag.TypeUnion, "padding").?;
+            return @ptrCast(&ip.extra.items[self.extra_index + padding_field_index]);
+        }
+
         pub fn haveFieldTypes(self: @This(), ip: *const InternPool) bool {
             return self.flagsPtr(ip).status.haveFieldTypes();
         }
@@ -1744,6 +1756,10 @@ pub const UnionType = struct {
     enum_tag_ty: Index,
     /// The integer tag type of the enum.
     int_tag_ty: Index,
+    /// ABI size of the union, including padding
+    size: u64,
+    /// Trailing padding bytes
+    padding: u32,
     /// List of field names in declaration order.
     field_names: NullTerminatedString.Slice,
     /// List of field types in declaration order.
@@ -1830,6 +1846,10 @@ pub const UnionType = struct {
         return self.flagsPtr(ip).runtime_tag.hasTag();
     }
 
+    pub fn haveFieldTypes(self: UnionType, ip: *const InternPool) bool {
+        return self.flagsPtr(ip).status.haveFieldTypes();
+    }
+
     pub fn haveLayout(self: UnionType, ip: *const InternPool) bool {
         return self.flagsPtr(ip).status.haveLayout();
     }
@@ -1867,6 +1887,8 @@ pub fn loadUnionType(ip: *InternPool, key: Key.UnionType) UnionType {
         .namespace = type_union.data.namespace,
         .enum_tag_ty = enum_ty,
         .int_tag_ty = enum_info.tag_ty,
+        .size = type_union.data.padding,
+        .padding = type_union.data.padding,
         .field_names = enum_info.names,
         .names_map = enum_info.names_map,
         .field_types = .{
@@ -2943,6 +2965,10 @@ pub const Tag = enum(u8) {
     /// 1. field align: Alignment for each field; declaration order
     pub const TypeUnion = struct {
         flags: Flags,
+        // Only valid after .have_layout
+        size: u32,
+        // Only valid after .have_layout
+        padding: u32,
         decl: Module.Decl.Index,
         namespace: Module.Namespace.Index,
         /// The enum that provides the list of field names and values.
@@ -2957,7 +2983,8 @@ pub const Tag = enum(u8) {
             status: UnionType.Status,
             requires_comptime: RequiresComptime,
             assumed_runtime_bits: bool,
-            _: u21 = 0,
+            alignment: Alignment,
+            _: u15 = 0,
         };
     };
 
@@ -3021,7 +3048,7 @@ pub const Tag = enum(u8) {
             any_comptime_fields: bool,
             any_default_inits: bool,
             any_aligned_fields: bool,
-            /// `undefined` until the layout_resolved
+            /// `.none` until layout_resolved
             alignment: Alignment,
             /// Dependency loop detection when resolving struct alignment.
             alignment_wip: bool,
@@ -5262,6 +5289,8 @@ pub fn getUnionType(ip: *InternPool, gpa: Allocator, ini: UnionTypeInit) Allocat
 
     const union_type_extra_index = ip.addExtraAssumeCapacity(Tag.TypeUnion{
         .flags = ini.flags,
+        .size = std.math.maxInt(u32),
+        .padding = std.math.maxInt(u32),
         .decl = ini.decl,
         .namespace = ini.namespace,
         .tag_ty = ini.enum_tag_ty,
src/Module.zig
@@ -6538,31 +6538,11 @@ pub fn getUnionLayout(mod: *Module, u: InternPool.UnionType) UnionLayout {
             .padding = 0,
         };
     }
-    // Put the tag before or after the payload depending on which one's
-    // alignment is greater.
+
     const tag_size = u.enum_tag_ty.toType().abiSize(mod);
     const tag_align = u.enum_tag_ty.toType().abiAlignment(mod).max(.@"1");
-    var size: u64 = 0;
-    var padding: u32 = undefined;
-    if (tag_align.compare(.gte, payload_align)) {
-        // {Tag, Payload}
-        size += tag_size;
-        size = payload_align.forward(size);
-        size += payload_size;
-        const prev_size = size;
-        size = tag_align.forward(size);
-        padding = @intCast(size - prev_size);
-    } else {
-        // {Payload, Tag}
-        size += payload_size;
-        size = tag_align.forward(size);
-        size += tag_size;
-        const prev_size = size;
-        size = payload_align.forward(size);
-        padding = @intCast(size - prev_size);
-    }
     return .{
-        .abi_size = size,
+        .abi_size = u.size,
         .abi_align = tag_align.max(payload_align),
         .most_aligned_field = most_aligned_field,
         .most_aligned_field_size = most_aligned_field_size,
@@ -6571,7 +6551,7 @@ pub fn getUnionLayout(mod: *Module, u: InternPool.UnionType) UnionLayout {
         .payload_align = payload_align,
         .tag_align = tag_align,
         .tag_size = tag_size,
-        .padding = padding,
+        .padding = u.padding,
     };
 }
 
src/Sema.zig
@@ -3200,6 +3200,7 @@ fn zirUnionDecl(
                 .any_aligned_fields = small.any_aligned_fields,
                 .requires_comptime = .unknown,
                 .assumed_runtime_bits = false,
+                .alignment = .none,
             },
             .decl = new_decl_index,
             .namespace = new_namespace_index,
@@ -20988,6 +20989,7 @@ fn zirReify(
                     .any_aligned_fields = any_aligned_fields,
                     .requires_comptime = .unknown,
                     .assumed_runtime_bits = false,
+                    .alignment = .none,
                 },
                 .field_types = union_fields.items(.type),
                 .field_aligns = if (any_aligned_fields) union_fields.items(.alignment) else &.{},
@@ -34921,11 +34923,56 @@ fn checkMemOperand(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) !void
     return sema.failWithOwnedErrorMsg(block, msg);
 }
 
+/// Resolve a unions's alignment only without triggering resolution of its layout.
+/// Asserts that the alignment is not yet resolved.
+pub fn resolveUnionAlignment(
+    sema: *Sema,
+    ty: Type,
+    union_type: InternPool.Key.UnionType,
+) CompileError!Alignment {
+    const mod = sema.mod;
+    const ip = &mod.intern_pool;
+    const target = mod.getTarget();
+
+    assert(!union_type.haveLayout(ip));
+
+    if (union_type.flagsPtr(ip).status == .field_types_wip) {
+        // We'll guess "pointer-aligned", if the union has an
+        // underaligned pointer field then some allocations
+        // might require explicit alignment.
+        return Alignment.fromByteUnits(@divExact(target.ptrBitWidth(), 8));
+    }
+
+    try sema.resolveTypeFieldsUnion(ty, union_type);
+
+    const union_obj = ip.loadUnionType(union_type);
+    var max_align: Alignment = .@"1";
+    for (0..union_obj.field_names.len) |field_index| {
+        const field_ty = union_obj.field_types.get(ip)[field_index].toType();
+        if (!(try sema.typeHasRuntimeBits(field_ty))) continue;
+
+        const explicit_align = union_obj.fieldAlign(ip, @intCast(field_index));
+        const field_align = if (explicit_align != .none)
+            explicit_align
+        else
+            try sema.typeAbiAlignment(field_ty);
+
+        max_align = max_align.max(field_align);
+    }
+
+    union_type.flagsPtr(ip).alignment = max_align;
+    return max_align;
+}
+
+/// This logic must be kept in sync with `Module.getUnionLayout`.
 fn resolveUnionLayout(sema: *Sema, ty: Type) CompileError!void {
     const mod = sema.mod;
     const ip = &mod.intern_pool;
-    try sema.resolveTypeFields(ty);
-    const union_obj = mod.typeToUnion(ty).?;
+
+    const union_type = ip.indexToKey(ty.ip_index).union_type;
+    try sema.resolveTypeFieldsUnion(ty, union_type);
+
+    const union_obj = ip.loadUnionType(union_type);
     switch (union_obj.flagsPtr(ip).status) {
         .none, .have_field_types => {},
         .field_types_wip, .layout_wip => {
@@ -34939,25 +34986,74 @@ fn resolveUnionLayout(sema: *Sema, ty: Type) CompileError!void {
         },
         .have_layout, .fully_resolved_wip, .fully_resolved => return,
     }
+
     const prev_status = union_obj.flagsPtr(ip).status;
     errdefer if (union_obj.flagsPtr(ip).status == .layout_wip) {
         union_obj.flagsPtr(ip).status = prev_status;
     };
 
     union_obj.flagsPtr(ip).status = .layout_wip;
-    for (0..union_obj.field_types.len) |field_index| {
+
+    var max_size: u64 = 0;
+    var max_align: Alignment = .@"1";
+    for (0..union_obj.field_names.len) |field_index| {
         const field_ty = union_obj.field_types.get(ip)[field_index].toType();
-        sema.resolveTypeLayout(field_ty) catch |err| switch (err) {
+        if (!(try sema.typeHasRuntimeBits(field_ty))) continue;
+
+        max_size = @max(max_size, sema.typeAbiSize(field_ty) catch |err| switch (err) {
             error.AnalysisFail => {
                 const msg = sema.err orelse return err;
                 try sema.addFieldErrNote(ty, field_index, msg, "while checking this field", .{});
                 return err;
             },
             else => return err,
-        };
-    }
-    union_obj.flagsPtr(ip).status = .have_layout;
-    _ = try sema.typeRequiresComptime(ty);
+        });
+
+        const explicit_align = union_obj.fieldAlign(ip, @intCast(field_index));
+        const field_align = if (explicit_align != .none)
+            explicit_align
+        else
+            try sema.typeAbiAlignment(field_ty);
+
+        max_align = max_align.max(field_align);
+    }
+
+    const flags = union_obj.flagsPtr(ip);
+    const has_runtime_tag = flags.runtime_tag.hasTag() and try sema.typeHasRuntimeBits(union_obj.enum_tag_ty.toType());
+    const size, const alignment, const padding = if (has_runtime_tag) layout: {
+        const enum_tag_type = union_obj.enum_tag_ty.toType();
+        const tag_align = try sema.typeAbiAlignment(enum_tag_type);
+        const tag_size = try sema.typeAbiSize(enum_tag_type);
+
+        // Put the tag before or after the payload depending on which one's
+        // alignment is greater.
+        var size: u64 = 0;
+        var padding: u32 = 0;
+        if (tag_align.compare(.gte, max_align)) {
+            // {Tag, Payload}
+            size += tag_size;
+            size = max_align.forward(size);
+            size += max_size;
+            const prev_size = size;
+            size = tag_align.forward(size);
+            padding = @intCast(size - prev_size);
+        } else {
+            // {Payload, Tag}
+            size += max_size;
+            size = tag_align.forward(size);
+            size += tag_size;
+            const prev_size = size;
+            size = max_align.forward(size);
+            padding = @intCast(size - prev_size);
+        }
+
+        break :layout .{ size, max_align.max(tag_align), padding };
+    } else .{ max_align.forward(max_size), max_align, 0 };
+
+    union_type.size(ip).* = @intCast(size);
+    union_type.padding(ip).* = padding;
+    flags.alignment = alignment;
+    flags.status = .have_layout;
 
     if (union_obj.flagsPtr(ip).assumed_runtime_bits and !(try sema.typeHasRuntimeBits(ty))) {
         const msg = try Module.ErrorMsg.create(
@@ -35034,7 +35130,6 @@ fn resolveStructFully(sema: *Sema, ty: Type) CompileError!void {
 
 fn resolveUnionFully(sema: *Sema, ty: Type) CompileError!void {
     try sema.resolveUnionLayout(ty);
-    try sema.resolveTypeFields(ty);
 
     const mod = sema.mod;
     const ip = &mod.intern_pool;
src/type.zig
@@ -1034,66 +1034,20 @@ pub const Type = struct {
                     }
                     return .{ .scalar = big_align };
                 },
-
                 .union_type => |union_type| {
-                    if (opt_sema) |sema| {
-                        if (union_type.flagsPtr(ip).status == .field_types_wip) {
-                            // We'll guess "pointer-aligned", if the union has an
-                            // underaligned pointer field then some allocations
-                            // might require explicit alignment.
-                            return .{ .scalar = Alignment.fromByteUnits(@divExact(target.ptrBitWidth(), 8)) };
-                        }
-                        _ = try sema.resolveTypeFields(ty);
-                    }
-                    if (!union_type.haveFieldTypes(ip)) switch (strat) {
+                    const flags = union_type.flagsPtr(ip).*;
+                    if (flags.alignment != .none) return .{ .scalar = flags.alignment };
+
+                    if (!union_type.haveLayout(ip)) switch (strat) {
                         .eager => unreachable, // union layout not resolved
-                        .sema => unreachable, // handled above
+                        .sema => |sema| return .{ .scalar = try sema.resolveUnionAlignment(ty, union_type) },
                         .lazy => return .{ .val = (try mod.intern(.{ .int = .{
                             .ty = .comptime_int_type,
                             .storage = .{ .lazy_align = ty.toIntern() },
                         } })).toValue() },
                     };
-                    const union_obj = ip.loadUnionType(union_type);
-                    if (union_obj.field_names.len == 0) {
-                        if (union_obj.hasTag(ip)) {
-                            return abiAlignmentAdvanced(union_obj.enum_tag_ty.toType(), mod, strat);
-                        } else {
-                            return .{ .scalar = .@"1" };
-                        }
-                    }
 
-                    var max_align: Alignment = .@"1";
-                    if (union_obj.hasTag(ip)) max_align = union_obj.enum_tag_ty.toType().abiAlignment(mod);
-                    for (0..union_obj.field_names.len) |field_index| {
-                        const field_ty = union_obj.field_types.get(ip)[field_index].toType();
-                        const field_align = if (union_obj.field_aligns.len == 0)
-                            .none
-                        else
-                            union_obj.field_aligns.get(ip)[field_index];
-                        if (!(field_ty.hasRuntimeBitsAdvanced(mod, false, strat) catch |err| switch (err) {
-                            error.NeedLazy => return .{ .val = (try mod.intern(.{ .int = .{
-                                .ty = .comptime_int_type,
-                                .storage = .{ .lazy_align = ty.toIntern() },
-                            } })).toValue() },
-                            else => |e| return e,
-                        })) continue;
-
-                        const field_align_bytes: Alignment = if (field_align != .none)
-                            field_align
-                        else switch (try field_ty.abiAlignmentAdvanced(mod, strat)) {
-                            .scalar => |a| a,
-                            .val => switch (strat) {
-                                .eager => unreachable, // struct layout not resolved
-                                .sema => unreachable, // handled above
-                                .lazy => return .{ .val = (try mod.intern(.{ .int = .{
-                                    .ty = .comptime_int_type,
-                                    .storage = .{ .lazy_align = ty.toIntern() },
-                                } })).toValue() },
-                            },
-                        };
-                        max_align = max_align.max(field_align_bytes);
-                    }
-                    return .{ .scalar = max_align };
+                    return .{ .scalar = union_type.flagsPtr(ip).alignment };
                 },
                 .opaque_type => return .{ .scalar = .@"1" },
                 .enum_type => |enum_type| return .{
@@ -1451,8 +1405,8 @@ pub const Type = struct {
                         },
                         .eager => {},
                     }
-                    const union_obj = ip.loadUnionType(union_type);
-                    return AbiSizeAdvanced{ .scalar = mod.unionAbiSize(union_obj) };
+
+                    return .{ .scalar = union_type.size(ip).* };
                 },
                 .opaque_type => unreachable, // no size available
                 .enum_type => |enum_type| return AbiSizeAdvanced{ .scalar = enum_type.tag_ty.toType().abiSize(mod) },
@@ -2680,11 +2634,11 @@ pub const Type = struct {
                             if (struct_type.flagsPtr(ip).field_types_wip)
                                 return false;
 
-                            try sema.resolveTypeFieldsStruct(ty.toIntern(), struct_type);
-
                             struct_type.flagsPtr(ip).requires_comptime = .wip;
                             errdefer struct_type.flagsPtr(ip).requires_comptime = .unknown;
 
+                            try sema.resolveTypeFieldsStruct(ty.toIntern(), struct_type);
+
                             for (0..struct_type.field_types.len) |i_usize| {
                                 const i: u32 = @intCast(i_usize);
                                 if (struct_type.fieldIsComptime(ip, i)) continue;
@@ -2723,12 +2677,12 @@ pub const Type = struct {
                         if (union_type.flagsPtr(ip).status == .field_types_wip)
                             return false;
 
-                        try sema.resolveTypeFieldsUnion(ty, union_type);
-                        const union_obj = ip.loadUnionType(union_type);
+                        union_type.flagsPtr(ip).requires_comptime = .wip;
+                        errdefer union_type.flagsPtr(ip).requires_comptime = .unknown;
 
-                        union_obj.flagsPtr(ip).requires_comptime = .wip;
-                        errdefer union_obj.flagsPtr(ip).requires_comptime = .unknown;
+                        try sema.resolveTypeFieldsUnion(ty, union_type);
 
+                        const union_obj = ip.loadUnionType(union_type);
                         for (0..union_obj.field_types.len) |field_idx| {
                             const field_ty = union_obj.field_types.get(ip)[field_idx];
                             if (try field_ty.toType().comptimeOnlyAdvanced(mod, opt_sema)) {