Commit 3c3bc5af29

Andrew Kelley <andrew@ziglang.org>
2022-06-11 00:04:39
Sema: introduce bitSizeAdvanced to recursively resolve types
Same pattern as abiSizeAdvanced. Fixes compiler crash for nested packed structs.
1 parent 58bc562
Changed files (3)
src/Sema.zig
@@ -11747,11 +11747,11 @@ fn zirSizeOf(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
 
 fn zirBitSizeOf(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
+    const src = inst_data.src();
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
-    const unresolved_operand_ty = try sema.resolveType(block, operand_src, inst_data.operand);
-    const operand_ty = try sema.resolveTypeFields(block, operand_src, unresolved_operand_ty);
+    const operand_ty = try sema.resolveType(block, operand_src, inst_data.operand);
     const target = sema.mod.getTarget();
-    const bit_size = operand_ty.bitSize(target);
+    const bit_size = try operand_ty.bitSizeAdvanced(target, sema.kit(block, src));
     return sema.addIntUnsigned(Type.comptime_int, bit_size);
 }
 
src/type.zig
@@ -3542,9 +3542,19 @@ pub const Type = extern union {
         );
     }
 
-    /// Asserts the type has the bit size already resolved.
     pub fn bitSize(ty: Type, target: Target) u64 {
-        return switch (ty.tag()) {
+        return bitSizeAdvanced(ty, target, null) catch unreachable;
+    }
+
+    /// If you pass `sema_kit`, any recursive type resolutions will happen if
+    /// necessary, possibly returning a CompileError. Passing `null` instead asserts
+    /// the type is fully resolved, and there will be no error, guaranteed.
+    pub fn bitSizeAdvanced(
+        ty: Type,
+        target: Target,
+        sema_kit: ?Module.WipAnalysis,
+    ) Module.CompileError!u64 {
+        switch (ty.tag()) {
             .fn_noreturn_no_args => unreachable, // represents machine code; not a pointer
             .fn_void_no_args => unreachable, // represents machine code; not a pointer
             .fn_naked_noreturn_no_args => unreachable, // represents machine code; not a pointer
@@ -3568,40 +3578,30 @@ pub const Type = extern union {
             .generic_poison => unreachable,
             .bound_fn => unreachable,
 
-            .void => 0,
-            .bool, .u1 => 1,
-            .u8, .i8 => 8,
-            .i16, .u16, .f16 => 16,
-            .u29 => 29,
-            .i32, .u32, .f32 => 32,
-            .i64, .u64, .f64 => 64,
-            .f80 => 80,
-            .u128, .i128, .f128 => 128,
+            .void => return 0,
+            .bool, .u1 => return 1,
+            .u8, .i8 => return 8,
+            .i16, .u16, .f16 => return 16,
+            .u29 => return 29,
+            .i32, .u32, .f32 => return 32,
+            .i64, .u64, .f64 => return 64,
+            .f80 => return 80,
+            .u128, .i128, .f128 => return 128,
 
             .@"struct" => {
-                const field_count = ty.structFieldCount();
-                if (field_count == 0) return 0;
-
-                const struct_obj = ty.castTag(.@"struct").?.data;
-                assert(struct_obj.haveFieldTypes());
-
-                switch (struct_obj.layout) {
-                    .Auto, .Extern => {
-                        var total: u64 = 0;
-                        for (struct_obj.fields.values()) |field| {
-                            total += field.ty.bitSize(target);
-                        }
-                        return total;
-                    },
-                    .Packed => return struct_obj.packedIntegerBits(target),
+                if (sema_kit) |sk| _ = try sk.sema.resolveTypeFields(sk.block, sk.src, ty);
+                var total: u64 = 0;
+                for (ty.structFields().values()) |field| {
+                    total += try bitSizeAdvanced(field.ty, target, sema_kit);
                 }
+                return total;
             },
 
             .tuple, .anon_struct => {
-                const tuple = ty.tupleFields();
+                if (sema_kit) |sk| _ = try sk.sema.resolveTypeFields(sk.block, sk.src, ty);
                 var total: u64 = 0;
-                for (tuple.types) |field_ty| {
-                    total += field_ty.bitSize(target);
+                for (ty.tupleFields().types) |field_ty| {
+                    total += try bitSizeAdvanced(field_ty, target, sema_kit);
                 }
                 return total;
             },
@@ -3609,37 +3609,35 @@ pub const Type = extern union {
             .enum_simple, .enum_full, .enum_nonexhaustive, .enum_numbered => {
                 var buffer: Payload.Bits = undefined;
                 const int_tag_ty = ty.intTagType(&buffer);
-                return int_tag_ty.bitSize(target);
+                return try bitSizeAdvanced(int_tag_ty, target, sema_kit);
             },
 
             .@"union", .union_tagged => {
+                if (sema_kit) |sk| _ = try sk.sema.resolveTypeFields(sk.block, sk.src, ty);
                 const union_obj = ty.cast(Payload.Union).?.data;
-
-                const fields = union_obj.fields;
-                if (fields.count() == 0) return 0;
-
                 assert(union_obj.haveFieldTypes());
 
                 var size: u64 = 0;
-                for (fields.values()) |field| {
-                    size = @maximum(size, field.ty.bitSize(target));
+                for (union_obj.fields.values()) |field| {
+                    size = @maximum(size, try bitSizeAdvanced(field.ty, target, sema_kit));
                 }
                 return size;
             },
 
             .vector => {
                 const payload = ty.castTag(.vector).?.data;
-                const elem_bit_size = payload.elem_type.bitSize(target);
+                const elem_bit_size = try bitSizeAdvanced(payload.elem_type, target, sema_kit);
                 return elem_bit_size * payload.len;
             },
-            .array_u8 => 8 * ty.castTag(.array_u8).?.data,
-            .array_u8_sentinel_0 => 8 * (ty.castTag(.array_u8_sentinel_0).?.data + 1),
+            .array_u8 => return 8 * ty.castTag(.array_u8).?.data,
+            .array_u8_sentinel_0 => return 8 * (ty.castTag(.array_u8_sentinel_0).?.data + 1),
             .array => {
                 const payload = ty.castTag(.array).?.data;
                 const elem_size = std.math.max(payload.elem_type.abiAlignment(target), payload.elem_type.abiSize(target));
                 if (elem_size == 0 or payload.len == 0)
-                    return 0;
-                return (payload.len - 1) * 8 * elem_size + payload.elem_type.bitSize(target);
+                    return @as(u64, 0);
+                const elem_bit_size = try bitSizeAdvanced(payload.elem_type, target, sema_kit);
+                return (payload.len - 1) * 8 * elem_size + elem_bit_size;
             },
             .array_sentinel => {
                 const payload = ty.castTag(.array_sentinel).?.data;
@@ -3647,14 +3645,15 @@ pub const Type = extern union {
                     payload.elem_type.abiAlignment(target),
                     payload.elem_type.abiSize(target),
                 );
-                return payload.len * 8 * elem_size + payload.elem_type.bitSize(target);
+                const elem_bit_size = try bitSizeAdvanced(payload.elem_type, target, sema_kit);
+                return payload.len * 8 * elem_size + elem_bit_size;
             },
 
             .isize,
             .usize,
             .@"anyframe",
             .anyframe_T,
-            => target.cpu.arch.ptrBitWidth(),
+            => return target.cpu.arch.ptrBitWidth(),
 
             .const_slice,
             .mut_slice,
@@ -3662,7 +3661,7 @@ pub const Type = extern union {
 
             .const_slice_u8,
             .const_slice_u8_sentinel_0,
-            => target.cpu.arch.ptrBitWidth() * 2,
+            => return target.cpu.arch.ptrBitWidth() * 2,
 
             .optional_single_const_pointer,
             .optional_single_mut_pointer,
@@ -3681,8 +3680,8 @@ pub const Type = extern union {
             },
 
             .pointer => switch (ty.castTag(.pointer).?.data.size) {
-                .Slice => target.cpu.arch.ptrBitWidth() * 2,
-                else => target.cpu.arch.ptrBitWidth(),
+                .Slice => return target.cpu.arch.ptrBitWidth() * 2,
+                else => return target.cpu.arch.ptrBitWidth(),
             },
 
             .manyptr_u8,
@@ -3708,7 +3707,7 @@ pub const Type = extern union {
             .error_set_merged,
             => return 16, // TODO revisit this when we have the concept of the error tag type
 
-            .int_signed, .int_unsigned => ty.cast(Payload.Bits).?.data,
+            .int_signed, .int_unsigned => return ty.cast(Payload.Bits).?.data,
 
             .optional => {
                 var buf: Payload.ElemType = undefined;
@@ -3722,7 +3721,8 @@ pub const Type = extern union {
                 // field and a boolean as the second. Since the child type's abi alignment is
                 // guaranteed to be >= that of bool's (1 byte) the added size is exactly equal
                 // to the child type's ABI alignment.
-                return child_type.bitSize(target) + 1;
+                const child_bit_size = try bitSizeAdvanced(child_type, target, sema_kit);
+                return child_bit_size + 1;
             },
 
             .error_union => {
@@ -3730,9 +3730,9 @@ pub const Type = extern union {
                 if (!payload.error_set.hasRuntimeBits() and !payload.payload.hasRuntimeBits()) {
                     return 0;
                 } else if (!payload.error_set.hasRuntimeBits()) {
-                    return payload.payload.bitSize(target);
+                    return payload.payload.bitSizeAdvanced(target, sema_kit);
                 } else if (!payload.payload.hasRuntimeBits()) {
-                    return payload.error_set.bitSize(target);
+                    return payload.error_set.bitSizeAdvanced(target, sema_kit);
                 }
                 @panic("TODO bitSize error union");
             },
@@ -3749,7 +3749,7 @@ pub const Type = extern union {
             .extern_options,
             .type_info,
             => @panic("TODO at some point we gotta resolve builtin types"),
-        };
+        }
     }
 
     pub fn isSinglePointer(self: Type) bool {
@@ -5506,6 +5506,7 @@ pub const Type = extern union {
         switch (ty.tag()) {
             .@"struct" => {
                 const struct_obj = ty.castTag(.@"struct").?.data;
+                assert(struct_obj.haveFieldTypes());
                 return struct_obj.fields.count();
             },
             .empty_struct, .empty_struct_literal => return 0,
test/behavior/packed-struct.zig
@@ -243,7 +243,12 @@ test "correct sizeOf and offsets in packed structs" {
 }
 
 test "nested packed structs" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
 
     const S1 = packed struct { a: u8, b: u8, c: u8 };
 
@@ -253,7 +258,7 @@ test "nested packed structs" {
     const S3Padded = packed struct { s3: S3, pad: u16 };
 
     try expectEqual(48, @bitSizeOf(S3));
-    try expectEqual(6, @sizeOf(S3));
+    try expectEqual(@sizeOf(u48), @sizeOf(S3));
 
     try expectEqual(3, @offsetOf(S3, "y"));
     try expectEqual(24, @bitOffsetOf(S3, "y"));
@@ -273,7 +278,7 @@ test "nested packed structs" {
     const S6 = packed struct { a: i32, b: S4, c: i8 };
 
     const expectedBitSize = 80;
-    const expectedByteSize = expectedBitSize / 8;
+    const expectedByteSize = @sizeOf(u80);
     try expectEqual(expectedBitSize, @bitSizeOf(S5));
     try expectEqual(expectedByteSize, @sizeOf(S5));
     try expectEqual(expectedBitSize, @bitSizeOf(S6));