Commit f2a24b48e1

kcbanner <kcbanner@gmail.com>
2023-09-23 19:03:03
sema: rework the comptime representation of comptime unions
When the tag is not known, it's set to `.none`. In this case, the value is either an array of bytes (for extern unions) or an integer (for packed unions).
1 parent 2fddd76
src/arch/wasm/CodeGen.zig
@@ -3259,10 +3259,7 @@ fn lowerConstant(func: *CodeGen, arg_val: Value, ty: Type) InnerError!WValue {
         .un => |un| {
             // in this case we have a packed union which will not be passed by reference.
             const union_obj = mod.typeToUnion(ty).?;
-            const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
-                assert(union_obj.getLayout(ip) == .Extern);
-                break :f mod.unionLargestField(union_obj).index;
-            };
+            const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
             const field_ty = union_obj.field_types.get(ip)[field_index].toType();
             return func.lowerConstant(un.val.toValue(), field_ty);
         },
src/codegen/c.zig
@@ -1439,10 +1439,7 @@ pub const DeclGen = struct {
                 }
 
                 const union_obj = mod.typeToUnion(ty).?;
-                const field_i = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
-                    assert(union_obj.getLayout(ip) == .Extern);
-                    break :f mod.unionLargestField(union_obj).index;
-                };
+                const field_i = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
                 const field_ty = union_obj.field_types.get(ip)[field_i].toType();
                 const field_name = union_obj.field_names.get(ip)[field_i];
                 if (union_obj.getLayout(ip) == .Packed) {
src/codegen/llvm.zig
@@ -4108,28 +4108,28 @@ pub const Object = struct {
                 if (layout.payload_size == 0) return o.lowerValue(un.tag);
 
                 const union_obj = mod.typeToUnion(ty).?;
-                const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
-                    assert(union_obj.getLayout(ip) == .Extern);
-                    break :f mod.unionLargestField(union_obj).index;
-                };
+                const container_layout = union_obj.getLayout(ip);
+
+                var need_unnamed = false;
+                const payload = if (un.tag != .none) p: {
+                    const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
+                    const field_ty = union_obj.field_types.get(ip)[field_index].toType();
+                    if (container_layout == .Packed) {
+                        if (!field_ty.hasRuntimeBits(mod)) return o.builder.intConst(union_ty, 0);
+                        const small_int_val = try o.builder.castConst(
+                            if (field_ty.isPtrAtRuntime(mod)) .ptrtoint else .bitcast,
+                            try o.lowerValue(un.val),
+                            try o.builder.intType(@intCast(field_ty.bitSize(mod))),
+                        );
+                        return o.builder.convConst(.unsigned, small_int_val, union_ty);
+                    }
 
-                const field_ty = union_obj.field_types.get(ip)[field_index].toType();
-                if (union_obj.getLayout(ip) == .Packed) {
-                    if (!field_ty.hasRuntimeBits(mod)) return o.builder.intConst(union_ty, 0);
-                    const small_int_val = try o.builder.castConst(
-                        if (field_ty.isPtrAtRuntime(mod)) .ptrtoint else .bitcast,
-                        try o.lowerValue(un.val),
-                        try o.builder.intType(@intCast(field_ty.bitSize(mod))),
-                    );
-                    return o.builder.convConst(.unsigned, small_int_val, union_ty);
-                }
+                    // Sometimes we must make an unnamed struct because LLVM does
+                    // not support bitcasting our payload struct to the true union payload type.
+                    // Instead we use an unnamed struct and every reference to the global
+                    // must pointer cast to the expected type before accessing the union.
+                    need_unnamed = layout.most_aligned_field != field_index;
 
-                // Sometimes we must make an unnamed struct because LLVM does
-                // not support bitcasting our payload struct to the true union payload type.
-                // Instead we use an unnamed struct and every reference to the global
-                // must pointer cast to the expected type before accessing the union.
-                var need_unnamed = layout.most_aligned_field != field_index;
-                const payload = p: {
                     if (!field_ty.hasRuntimeBitsIgnoreComptime(mod)) {
                         const padding_len = layout.payload_size;
                         break :p try o.builder.undefConst(try o.builder.arrayType(padding_len, .i8));
@@ -4147,9 +4147,23 @@ pub const Object = struct {
                         try o.builder.structType(.@"packed", &.{ payload_ty, padding_ty }),
                         &.{ payload, try o.builder.undefConst(padding_ty) },
                     );
+                } else p: {
+                    assert(layout.tag_size == 0);
+                    const union_val = try o.lowerValue(un.val);
+                    if (container_layout == .Packed) {
+                        const bitcast_val = try o.builder.castConst(
+                            .bitcast,
+                            union_val,
+                            try o.builder.intType(@intCast(ty.bitSize(mod))),
+                        );
+                        return o.builder.convConst(.unsigned, bitcast_val, union_ty);
+                    }
+
+                    need_unnamed = true;
+                    break :p union_val;
                 };
-                const payload_ty = payload.typeOf(&o.builder);
 
+                const payload_ty = payload.typeOf(&o.builder);
                 if (layout.tag_size == 0) return o.builder.structConst(if (need_unnamed)
                     try o.builder.structType(union_ty.structKind(&o.builder), &.{payload_ty})
                 else
src/codegen/spirv.zig
@@ -838,10 +838,7 @@ pub const DeclGen = struct {
                         return dg.todo("packed union constants", .{});
                     }
 
-                    const active_field = ty.unionTagFieldIndex(un.tag.toValue(), dg.module) orelse f: {
-                        assert(union_obj.getLayout(ip) == .Extern);
-                        break :f mod.unionLargestField(union_obj).index;
-                    };
+                    const active_field = ty.unionTagFieldIndex(un.tag.toValue(), dg.module).?;
                     const active_field_ty = union_obj.field_types.get(ip)[active_field].toType();
 
                     const has_tag = layout.tag_size != 0;
src/codegen.zig
@@ -583,10 +583,7 @@ pub fn generateSymbol(
             }
 
             const union_obj = mod.typeToUnion(typed_value.ty).?;
-            const field_index = typed_value.ty.unionTagFieldIndex(un.tag.toValue(), mod) orelse f: {
-                assert(union_obj.getLayout(ip) == .Extern);
-                break :f mod.unionLargestField(union_obj).index;
-            };
+            const field_index = typed_value.ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
 
             const field_ty = union_obj.field_types.get(ip)[field_index].toType();
             if (!field_ty.hasRuntimeBits(mod)) {
src/InternPool.zig
@@ -1105,7 +1105,10 @@ pub const Key = union(enum) {
     pub const Union = extern struct {
         /// This is the union type; not the field type.
         ty: Index,
-        /// Indicates the active field.
+        /// Indicates the active field. This could be `none`, which indicates the tag is not known. `none` is only a valid value for extern and packed unions.
+        /// In those cases, the type of `val` is:
+        ///   extern: a u8 array of the same byte length as the union
+        ///   packed: an unsigned integer with the same bit size as the union
         tag: Index,
         /// The value of the active field.
         val: Index,
@@ -5130,7 +5133,6 @@ pub fn get(ip: *InternPool, gpa: Allocator, key: Key) Allocator.Error!Index {
 
         .un => |un| {
             assert(un.ty != .none);
-            assert(un.tag != .none);
             assert(un.val != .none);
             ip.items.appendAssumeCapacity(.{
                 .tag = .union_value,
src/Module.zig
@@ -5823,7 +5823,7 @@ pub fn markReferencedDeclsAlive(mod: *Module, val: Value) Allocator.Error!void {
         .aggregate => |aggregate| for (aggregate.storage.values()) |elem|
             try mod.markReferencedDeclsAlive(elem.toValue()),
         .un => |un| {
-            try mod.markReferencedDeclsAlive(un.tag.toValue());
+            if (un.tag != .none) try mod.markReferencedDeclsAlive(un.tag.toValue());
             try mod.markReferencedDeclsAlive(un.val.toValue());
         },
         else => {},
@@ -6607,7 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in
 
 pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 {
     const ip = &mod.intern_pool;
-    if (enum_tag.toIntern() == .undef) return null;
+    if (enum_tag.toIntern() == .none) return null;
     assert(ip.typeOf(enum_tag.toIntern()) == u.enum_tag_ty);
     const enum_type = ip.indexToKey(u.enum_tag_ty).enum_type;
     return enum_type.tagValueIndex(ip, enum_tag.toIntern());
@@ -6673,30 +6673,3 @@ pub fn structPackedFieldBitOffset(
     }
     unreachable; // index out of bounds
 }
-
-pub fn unionLargestField(mod: *Module, u: InternPool.UnionType) struct {
-    ty: Type,
-    index: u32,
-    size: u64,
-} {
-    const fields = u.field_types.get(&mod.intern_pool);
-    assert(fields.len != 0);
-    var largest_field_ty: Type = undefined;
-    var largest_field_size: u64 = 0;
-    var largest_field_index: u32 = 0;
-    for (fields, 0..) |union_field, i| {
-        const field_ty = union_field.toType();
-        const size: u32 = @intCast(field_ty.abiSize(mod));
-        if (size > largest_field_size) {
-            largest_field_ty = field_ty;
-            largest_field_size = size;
-            largest_field_index = @intCast(i);
-        }
-    }
-
-    return .{
-        .ty = largest_field_ty,
-        .index = largest_field_index,
-        .size = largest_field_size,
-    };
-}
src/Sema.zig
@@ -12124,7 +12124,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
 
         const analyze_body = if (union_originally) blk: {
             const item_val = sema.resolveConstLazyValue(block, .unneeded, item, undefined) catch unreachable;
-            const field_ty = maybe_union_ty.unionFieldType(item_val, mod);
+            const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?;
             break :blk field_ty.zigTypeTag(mod) != .NoReturn;
         } else true;
 
@@ -12250,7 +12250,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
 
                 const analyze_body = if (union_originally) blk: {
                     const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable;
-                    const field_ty = maybe_union_ty.unionFieldType(item_val, mod);
+                    const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?;
                     break :blk field_ty.zigTypeTag(mod) != .NoReturn;
                 } else true;
 
@@ -12304,7 +12304,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
             const analyze_body = if (union_originally)
                 for (items) |item| {
                     const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable;
-                    const field_ty = maybe_union_ty.unionFieldType(item_val, mod);
+                    const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?;
                     if (field_ty.zigTypeTag(mod) != .NoReturn) break true;
                 } else false
             else
@@ -12456,7 +12456,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
                     case_block.wip_capture_scope = child_block.wip_capture_scope;
 
                     const analyze_body = if (union_originally) blk: {
-                        const field_ty = maybe_union_ty.unionFieldType(item_val, mod);
+                        const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?;
                         break :blk field_ty.zigTypeTag(mod) != .NoReturn;
                     } else true;
 
@@ -16496,7 +16496,7 @@ fn analyzeCmpUnionTag(
 
     if (try sema.resolveMaybeUndefVal(coerced_tag)) |enum_val| {
         if (enum_val.isUndef(mod)) return mod.undefRef(Type.bool);
-        const field_ty = union_ty.unionFieldType(enum_val, mod);
+        const field_ty = union_ty.unionFieldType(enum_val, mod).?;
         if (field_ty.zigTypeTag(mod) == .NoReturn) {
             return .bool_false;
         }
@@ -27208,7 +27208,11 @@ fn unionFieldVal(
                 if (tag_matches) {
                     return Air.internedToRef(un.val);
                 } else {
-                    const old_ty = union_ty.unionFieldType(un.tag.toValue(), mod);
+                    const old_ty = if (un.tag == .none)
+                        ip.typeOf(un.val).toType()
+                    else
+                        union_ty.unionFieldType(un.tag.toValue(), mod).?;
+
                     if (try sema.bitCastVal(block, src, un.val.toValue(), old_ty, field_ty, 0)) |new_val| {
                         return Air.internedToRef(new_val.toIntern());
                     }
src/type.zig
@@ -1658,6 +1658,7 @@ pub const Type = struct {
                     const field_ty = union_obj.field_types.get(ip)[field_index];
                     size = @max(size, try bitSizeAdvanced(field_ty.toType(), mod, opt_sema));
                 }
+
                 return size;
             },
             .opaque_type => unreachable,
@@ -1926,15 +1927,12 @@ pub const Type = struct {
         return union_obj.enum_tag_ty.toType();
     }
 
-    pub fn unionFieldType(ty: Type, enum_tag: Value, mod: *Module) Type {
+    pub fn unionFieldType(ty: Type, enum_tag: Value, mod: *Module) ?Type {
         const ip = &mod.intern_pool;
         const union_obj = mod.typeToUnion(ty).?;
         const union_fields = union_obj.field_types.get(ip);
-        if (mod.unionTagFieldIndex(union_obj, enum_tag)) |index| {
-            return union_fields[index].toType();
-        } else {
-            return mod.unionLargestField(union_obj).ty;
-        }
+        const index = mod.unionTagFieldIndex(union_obj, enum_tag) orelse return null;
+        return union_fields[index].toType();
     }
 
     pub fn unionTagFieldIndex(ty: Type, enum_tag: Value, mod: *Module) ?u32 {
src/TypedValue.zig
@@ -92,10 +92,14 @@ pub fn print(
                     .val = union_val.tag,
                 }, writer, level - 1, mod);
                 try writer.writeAll(" = ");
-                try print(.{
-                    .ty = ty.unionFieldType(union_val.tag, mod),
-                    .val = union_val.val,
-                }, writer, level - 1, mod);
+                if (ty.unionFieldType(union_val.tag, mod)) |field_ty| {
+                    try print(.{
+                        .ty = field_ty,
+                        .val = union_val.val,
+                    }, writer, level - 1, mod);
+                } else {
+                    return writer.writeAll("(no tag)");
+                }
 
                 return writer.writeAll(" }");
             },
@@ -409,10 +413,14 @@ pub fn print(
                         .val = un.tag.toValue(),
                     }, writer, level - 1, mod);
                     try writer.writeAll(" = ");
-                    try print(.{
-                        .ty = ty.unionFieldType(un.tag.toValue(), mod),
-                        .val = un.val.toValue(),
-                    }, writer, level - 1, mod);
+                    if (ty.unionFieldType(un.tag.toValue(), mod)) |field_ty| {
+                        try print(.{
+                            .ty = field_ty,
+                            .val = un.val.toValue(),
+                            }, writer, level - 1, mod);
+                    } else {
+                        try writer.writeAll("(no tag)");
+                    }
                 } else try writer.writeAll("...");
                 return writer.writeAll(" }");
             },
src/value.zig
@@ -330,7 +330,7 @@ pub const Value = struct {
                 return mod.intern(.{ .un = .{
                     .ty = ty.toIntern(),
                     .tag = try pl.tag.intern(ty.unionTagTypeHypothetical(mod), mod),
-                    .val = try pl.val.intern(ty.unionFieldType(pl.tag, mod), mod),
+                    .val = try pl.val.intern(ty.unionFieldType(pl.tag, mod).?, mod),
                 } });
             },
         }
@@ -703,22 +703,20 @@ pub const Value = struct {
                 std.mem.writeInt(Int, buffer[0..@sizeOf(Int)], @as(Int, @intCast(int)), endian);
             },
             .Union => switch (ty.containerLayout(mod)) {
-                .Auto => return error.IllDefinedMemoryLayout,
+                .Auto => return error.IllDefinedMemoryLayout, // Sema is supposed to have emitted a compile error already
                 .Extern => {
                     const union_obj = mod.typeToUnion(ty).?;
                     const union_tag = val.unionTag(mod);
-
-                    const field_type, const field_index = if (mod.unionTagFieldIndex(union_obj, union_tag)) |field_index| .{
-                        union_obj.field_types.get(&mod.intern_pool)[field_index].toType(),
-                        field_index,
-                    } else f: {
-                        const largest_field = mod.unionLargestField(union_obj);
-                        break :f .{ largest_field.ty, largest_field.index };
-                    };
-
-                    const field_val = try val.fieldValue(mod, field_index);
-                    const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
-                    return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
+                    if (mod.unionTagFieldIndex(union_obj, union_tag)) |field_index| {
+                        const field_type = union_obj.field_types.get(&mod.intern_pool)[field_index].toType();
+                        const field_val = try val.fieldValue(mod, field_index);
+                        const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
+                        return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
+                    } else {
+                        const union_size = ty.abiSize(mod);
+                        const array_type = try mod.arrayType(.{ .len = union_size, .child = .u8_type });
+                        return writeToMemory(val.unionValue(mod), array_type, mod, buffer[0..union_size]);
+                    }
                 },
                 .Packed => {
                     const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
@@ -832,13 +830,11 @@ pub const Value = struct {
             .Union => {
                 const union_obj = mod.typeToUnion(ty).?;
                 switch (union_obj.getLayout(ip)) {
-                    .Auto => unreachable, // Sema is supposed to have emitted a compile error already
-                    .Extern => unreachable, // Handled in non-packed writeToMemory
+                    .Auto, .Extern => unreachable, // Handled in non-packed writeToMemory
                     .Packed => {
                         const field_index = mod.unionTagFieldIndex(union_obj, val.unionTag(mod)).?;
                         const field_type = union_obj.field_types.get(ip)[field_index].toType();
                         const field_val = try val.fieldValue(mod, field_index);
-
                         return field_val.writeToPackedMemory(field_type, mod, buffer, bit_offset);
                     },
                 }
@@ -988,17 +984,14 @@ pub const Value = struct {
             .Union => switch (ty.containerLayout(mod)) {
                 .Auto => return error.IllDefinedMemoryLayout,
                 .Extern => {
-                    const union_obj = mod.typeToUnion(ty).?;
-                    const largest_field = mod.unionLargestField(union_obj);
-                    const field_size: usize = @intCast(largest_field.size);
-                    const val = try (try readFromMemory(largest_field.ty, mod, buffer[0..field_size], arena)).intern(largest_field.ty, mod);
-                    return (try mod.intern(.{
-                        .un = .{
-                            .ty = ty.toIntern(),
-                            .tag = .undef,
-                            .val = val,
-                        },
-                    })).toValue();
+                    const union_size = ty.abiSize(mod);
+                    const array_ty = try mod.arrayType(.{ .len = union_size, .child = .u8_type });
+                    const val = try (try readFromMemory(array_ty, mod, buffer, arena)).intern(array_ty, mod);
+                    return (try mod.intern(.{ .un = .{
+                        .ty = ty.toIntern(),
+                        .tag = .none,
+                        .val = val,
+                    } })).toValue();
                 },
                 .Packed => {
                     const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
@@ -1141,17 +1134,17 @@ pub const Value = struct {
                 } })).toValue();
             },
             .Union => switch (ty.containerLayout(mod)) {
-                .Auto => return error.IllDefinedMemoryLayout,
-                .Extern => unreachable, // Handled by non-packed readFromMemory
+                .Auto, .Extern => unreachable, // Handled by non-packed readFromMemory
                 .Packed => {
-                    const union_obj = mod.typeToUnion(ty).?;
-                    const largest_field = mod.unionLargestField(union_obj);
-                    const un_tag_val = try mod.enumValueFieldIndex(union_obj.enum_tag_ty.toType(), largest_field.index);
-                    const un_val = try (try readFromPackedMemory(largest_field.ty, mod, buffer, bit_offset, arena)).intern(largest_field.ty, mod);
+                    const union_bits: u16 = @intCast(ty.bitSize(mod));
+                    // TODO: Remove after tests pass
+                    assert(union_bits != 0);
+                    const int_ty = try mod.intType(.unsigned, union_bits);
+                    const val = (try readFromPackedMemory(int_ty, mod, buffer, bit_offset, arena)).toIntern();
                     return (try mod.intern(.{ .un = .{
                         .ty = ty.toIntern(),
-                        .tag = un_tag_val.ip_index,
-                        .val = un_val,
+                        .tag = .none,
+                        .val = val,
                     } })).toValue();
                 },
             },
test/behavior/comptime_memory.zig
@@ -455,23 +455,52 @@ test "type pun null pointer-like optional" {
 }
 
 test "reinterpret extern union" {
-    const U = extern union {
-        a: u32,
-        b: u64,
-    };
+    {
+        const U = extern union {
+            a: u32,
+            b: u8 align(8),
+        };
 
-    comptime var u: U = undefined;
-    comptime @memset(std.mem.asBytes(&u), 42);
-    try testing.expectEqual(@as(u64, 0x2a2a2a2a_2a2a2a2a), u.b);
+        comptime var u: U = undefined;
+        comptime @memset(std.mem.asBytes(&u), 42);
+        try comptime testing.expect(0x2a2a2a2a == u.a);
+        try comptime testing.expect(42 == u.b);
+        try testing.expectEqual(@as(u32, 0x2a2a2a2a), u.a);
+        try testing.expectEqual(42, u.b);
+    }
 }
 
 test "reinterpret packed union" {
-    const U = packed union {
-        a: u32,
-        b: u64,
-    };
+    {
+        const U = packed union {
+            a: u32,
+            b: u8 align(8),
+        };
 
-    comptime var u: U = undefined;
-    comptime @memset(std.mem.asBytes(&u), 42);
-    try testing.expectEqual(@as(u64, 0x2a2a2a2a_2a2a2a2a), u.b);
+        comptime var u: U = undefined;
+        comptime @memset(std.mem.asBytes(&u), 42);
+        try comptime testing.expect(0x2a2a2a2a == u.a);
+        try comptime testing.expect(0x2a == u.b);
+        try testing.expectEqual(@as(u32, 0x2a2a2a2a), u.a);
+        try testing.expectEqual(0x2a, u.b);
+    }
+
+    {
+        const U = packed union {
+            a: u7,
+            b: u1,
+        };
+
+        const S = packed struct {
+            lsb: U,
+            msb: U,
+        };
+
+        comptime var s: S = undefined;
+        comptime @memset(std.mem.asBytes(&s), 0xaa);
+        try comptime testing.expectEqual(@as(u7, 0x2a), s.lsb.a);
+        try comptime testing.expectEqual(@as(u1, 0), s.lsb.b);
+        try comptime testing.expectEqual(@as(u7, 0x55), s.msb.a);
+        try comptime testing.expectEqual(@as(u1, 1), s.msb.b);
+    }
 }