Commit 2fddd767ba

kcbanner <kcbanner@gmail.com>
2023-09-21 05:53:06
sema: add support for unions in readFromMemory and writeToMemory
1 parent ce919cc
src/arch/wasm/CodeGen.zig
@@ -3259,7 +3259,10 @@ fn lowerConstant(func: *CodeGen, arg_val: Value, ty: Type) InnerError!WValue {
         .un => |un| {
             // in this case we have a packed union which will not be passed by reference.
             const union_obj = mod.typeToUnion(ty).?;
-            const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
+            const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
+                assert(union_obj.getLayout(ip) == .Extern);
+                break :f mod.unionLargestField(union_obj).index;
+            };
             const field_ty = union_obj.field_types.get(ip)[field_index].toType();
             return func.lowerConstant(un.val.toValue(), field_ty);
         },
src/codegen/c.zig
@@ -1439,7 +1439,10 @@ pub const DeclGen = struct {
                 }
 
                 const union_obj = mod.typeToUnion(ty).?;
-                const field_i = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
+                const field_i = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
+                    assert(union_obj.getLayout(ip) == .Extern);
+                    break :f mod.unionLargestField(union_obj).index;
+                };
                 const field_ty = union_obj.field_types.get(ip)[field_i].toType();
                 const field_name = union_obj.field_names.get(ip)[field_i];
                 if (union_obj.getLayout(ip) == .Packed) {
src/codegen/llvm.zig
@@ -4108,7 +4108,10 @@ pub const Object = struct {
                 if (layout.payload_size == 0) return o.lowerValue(un.tag);
 
                 const union_obj = mod.typeToUnion(ty).?;
-                const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()).?;
+                const field_index = mod.unionTagFieldIndex(union_obj, un.tag.toValue()) orelse f: {
+                    assert(union_obj.getLayout(ip) == .Extern);
+                    break :f mod.unionLargestField(union_obj).index;
+                };
 
                 const field_ty = union_obj.field_types.get(ip)[field_index].toType();
                 if (union_obj.getLayout(ip) == .Packed) {
src/codegen/spirv.zig
@@ -838,7 +838,10 @@ pub const DeclGen = struct {
                         return dg.todo("packed union constants", .{});
                     }
 
-                    const active_field = ty.unionTagFieldIndex(un.tag.toValue(), dg.module).?;
+                    const active_field = ty.unionTagFieldIndex(un.tag.toValue(), dg.module) orelse f: {
+                        assert(union_obj.getLayout(ip) == .Extern);
+                        break :f mod.unionLargestField(union_obj).index;
+                    };
                     const active_field_ty = union_obj.field_types.get(ip)[active_field].toType();
 
                     const has_tag = layout.tag_size != 0;
src/codegen.zig
@@ -583,7 +583,11 @@ pub fn generateSymbol(
             }
 
             const union_obj = mod.typeToUnion(typed_value.ty).?;
-            const field_index = typed_value.ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
+            const field_index = typed_value.ty.unionTagFieldIndex(un.tag.toValue(), mod) orelse f: {
+                assert(union_obj.getLayout(ip) == .Extern);
+                break :f mod.unionLargestField(union_obj).index;
+            };
+
             const field_ty = union_obj.field_types.get(ip)[field_index].toType();
             if (!field_ty.hasRuntimeBits(mod)) {
                 try code.appendNTimes(0xaa, math.cast(usize, layout.payload_size) orelse return error.Overflow);
src/Module.zig
@@ -6607,6 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in
 
 pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 {
     const ip = &mod.intern_pool;
+    if (enum_tag.toIntern() == .undef) return null;
     assert(ip.typeOf(enum_tag.toIntern()) == u.enum_tag_ty);
     const enum_type = ip.indexToKey(u.enum_tag_ty).enum_type;
     return enum_type.tagValueIndex(ip, enum_tag.toIntern());
@@ -6672,3 +6673,30 @@ pub fn structPackedFieldBitOffset(
     }
     unreachable; // index out of bounds
 }
+
+pub fn unionLargestField(mod: *Module, u: InternPool.UnionType) struct {
+    ty: Type,
+    index: u32,
+    size: u64,
+} {
+    const fields = u.field_types.get(&mod.intern_pool);
+    assert(fields.len != 0);
+    var largest_field_ty: Type = undefined;
+    var largest_field_size: u64 = 0;
+    var largest_field_index: u32 = 0;
+    for (fields, 0..) |union_field, i| {
+        const field_ty = union_field.toType();
+        const size: u32 = @intCast(field_ty.abiSize(mod));
+        if (size > largest_field_size) {
+            largest_field_ty = field_ty;
+            largest_field_size = size;
+            largest_field_index = @intCast(i);
+        }
+    }
+
+    return .{
+        .ty = largest_field_ty,
+        .index = largest_field_index,
+        .size = largest_field_size,
+    };
+}
src/Sema.zig
@@ -29740,10 +29740,15 @@ fn storePtrVal(
                 error.OutOfMemory => return error.OutOfMemory,
                 error.ReinterpretDeclRef => unreachable,
                 error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
-                error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{mut_kit.ty.fmt(mod)}),
+                error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
             };
 
-            reinterpret.val_ptr.* = (try (try Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena)).intern(mut_kit.ty, mod)).toValue();
+            const val = Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena) catch |err| switch (err) {
+                error.OutOfMemory => return error.OutOfMemory,
+                error.IllDefinedMemoryLayout => unreachable,
+                error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{mut_kit.ty.fmt(mod)}),
+            };
+            reinterpret.val_ptr.* = (try val.intern(mut_kit.ty, mod)).toValue();
         },
         .bad_decl_ty, .bad_ptr_ty => {
             // TODO show the decl declaration site in a note and explain whether the decl
@@ -30655,7 +30660,12 @@ fn bitCastVal(
         error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
         error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{old_ty.fmt(mod)}),
     };
-    return try Value.readFromMemory(new_ty, mod, buffer[buffer_offset..], sema.arena);
+
+    return Value.readFromMemory(new_ty, mod, buffer[buffer_offset..], sema.arena) catch |err| switch (err) {
+        error.OutOfMemory => return error.OutOfMemory,
+        error.IllDefinedMemoryLayout => unreachable,
+        error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{new_ty.fmt(mod)}),
+    };
 }
 
 fn coerceArrayPtrToSlice(
src/type.zig
@@ -1929,8 +1929,12 @@ pub const Type = struct {
     pub fn unionFieldType(ty: Type, enum_tag: Value, mod: *Module) Type {
         const ip = &mod.intern_pool;
         const union_obj = mod.typeToUnion(ty).?;
-        const index = mod.unionTagFieldIndex(union_obj, enum_tag).?;
-        return union_obj.field_types.get(ip)[index].toType();
+        const union_fields = union_obj.field_types.get(ip);
+        if (mod.unionTagFieldIndex(union_obj, enum_tag)) |index| {
+            return union_fields[index].toType();
+        } else {
+            return mod.unionLargestField(union_obj).ty;
+        }
     }
 
     pub fn unionTagFieldIndex(ty: Type, enum_tag: Value, mod: *Module) ?u32 {
src/value.zig
@@ -704,7 +704,22 @@ pub const Value = struct {
             },
             .Union => switch (ty.containerLayout(mod)) {
                 .Auto => return error.IllDefinedMemoryLayout,
-                .Extern => return error.Unimplemented,
+                .Extern => {
+                    const union_obj = mod.typeToUnion(ty).?;
+                    const union_tag = val.unionTag(mod);
+
+                    const field_type, const field_index = if (mod.unionTagFieldIndex(union_obj, union_tag)) |field_index| .{
+                        union_obj.field_types.get(&mod.intern_pool)[field_index].toType(),
+                        field_index,
+                    } else f: {
+                        const largest_field = mod.unionLargestField(union_obj);
+                        break :f .{ largest_field.ty, largest_field.index };
+                    };
+
+                    const field_val = try val.fieldValue(mod, field_index);
+                    const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
+                    return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
+                },
                 .Packed => {
                     const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
                     return writeToPackedMemory(val, ty, mod, buffer[0..byte_count], 0);
@@ -856,7 +871,11 @@ pub const Value = struct {
         mod: *Module,
         buffer: []const u8,
         arena: Allocator,
-    ) Allocator.Error!Value {
+    ) error{
+        IllDefinedMemoryLayout,
+        Unimplemented,
+        OutOfMemory,
+    }!Value {
         const ip = &mod.intern_pool;
         const target = mod.getTarget();
         const endian = target.cpu.arch.endian();
@@ -966,6 +985,26 @@ pub const Value = struct {
                     .name = name,
                 } })).toValue();
             },
+            .Union => switch (ty.containerLayout(mod)) {
+                .Auto => return error.IllDefinedMemoryLayout,
+                .Extern => {
+                    const union_obj = mod.typeToUnion(ty).?;
+                    const largest_field = mod.unionLargestField(union_obj);
+                    const field_size: usize = @intCast(largest_field.size);
+                    const val = try (try readFromMemory(largest_field.ty, mod, buffer[0..field_size], arena)).intern(largest_field.ty, mod);
+                    return (try mod.intern(.{
+                        .un = .{
+                            .ty = ty.toIntern(),
+                            .tag = .undef,
+                            .val = val,
+                        },
+                    })).toValue();
+                },
+                .Packed => {
+                    const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
+                    return readFromPackedMemory(ty, mod, buffer[0..byte_count], 0, arena);
+                },
+            },
             .Pointer => {
                 assert(!ty.isSlice(mod)); // No well defined layout.
                 const int_val = try readFromMemory(Type.usize, mod, buffer, arena);
@@ -987,7 +1026,7 @@ pub const Value = struct {
                     },
                 } })).toValue();
             },
-            else => @panic("TODO implement readFromMemory for more types"),
+            else => return error.Unimplemented,
         }
     }
 
@@ -1001,7 +1040,10 @@ pub const Value = struct {
         buffer: []const u8,
         bit_offset: usize,
         arena: Allocator,
-    ) Allocator.Error!Value {
+    ) error{
+        IllDefinedMemoryLayout,
+        OutOfMemory,
+    }!Value {
         const ip = &mod.intern_pool;
         const target = mod.getTarget();
         const endian = target.cpu.arch.endian();
@@ -1098,6 +1140,21 @@ pub const Value = struct {
                     .storage = .{ .elems = field_vals },
                 } })).toValue();
             },
+            .Union => switch (ty.containerLayout(mod)) {
+                .Auto => return error.IllDefinedMemoryLayout,
+                .Extern => unreachable, // Handled by non-packed readFromMemory
+                .Packed => {
+                    const union_obj = mod.typeToUnion(ty).?;
+                    const largest_field = mod.unionLargestField(union_obj);
+                    const un_tag_val = try mod.enumValueFieldIndex(union_obj.enum_tag_ty.toType(), largest_field.index);
+                    const un_val = try (try readFromPackedMemory(largest_field.ty, mod, buffer, bit_offset, arena)).intern(largest_field.ty, mod);
+                    return (try mod.intern(.{ .un = .{
+                        .ty = ty.toIntern(),
+                        .tag = un_tag_val.ip_index,
+                        .val = un_val,
+                    } })).toValue();
+                },
+            },
             .Pointer => {
                 assert(!ty.isSlice(mod)); // No well defined layout.
                 return readFromPackedMemory(Type.usize, mod, buffer, bit_offset, arena);
@@ -1713,6 +1770,14 @@ pub const Value = struct {
         };
     }
 
+    pub fn unionValue(val: Value, mod: *Module) Value {
+        if (val.ip_index == .none) return val.castTag(.@"union").?.data.val;
+        return switch (mod.intern_pool.indexToKey(val.toIntern())) {
+            .un => |un| un.val.toValue(),
+            else => unreachable,
+        };
+    }
+
     /// Returns a pointer to the element value at the index.
     pub fn elemPtr(
         val: Value,
test/behavior/comptime_memory.zig
@@ -1,3 +1,4 @@
+const std = @import("std");
 const builtin = @import("builtin");
 const endian = builtin.cpu.arch.endian();
 const testing = @import("std").testing;
@@ -452,3 +453,25 @@ test "type pun null pointer-like optional" {
     // note that expectEqual hides the bug
     try testing.expect(@as(*const ?*i8, @ptrCast(&p)).* == null);
 }
+
+test "reinterpret extern union" {
+    const U = extern union {
+        a: u32,
+        b: u64,
+    };
+
+    comptime var u: U = undefined;
+    comptime @memset(std.mem.asBytes(&u), 42);
+    try testing.expectEqual(@as(u64, 0x2a2a2a2a_2a2a2a2a), u.b);
+}
+
+test "reinterpret packed union" {
+    const U = packed union {
+        a: u32,
+        b: u64,
+    };
+
+    comptime var u: U = undefined;
+    comptime @memset(std.mem.asBytes(&u), 42);
+    try testing.expectEqual(@as(u64, 0x2a2a2a2a_2a2a2a2a), u.b);
+}