Commit 240f9d740d

Robin Voetter <robin@voetter.nl>
2023-09-16 02:53:14
spirv: lower union initialization at runtime
1 parent d06862b
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -1301,6 +1301,72 @@ pub const DeclGen = struct {
                 .vector_type, .anon_struct_type => unreachable, // TODO
                 else => unreachable,
             },
+            .un => |un| {
+                // To initialize a union, generate a temporary variable with the
+                // type that has the right field active, then pointer-cast and store
+                // the active field, and finally load and return the entire union.
+
+                const layout = ty.unionGetLayout(mod);
+                const union_ty = mod.typeToUnion(ty).?;
+
+                if (union_ty.getLayout(ip) == .Packed) {
+                    return self.todo("packed union types", .{});
+                } else if (layout.payload_size == 0) {
+                    // No payload, so represent this as just the tag type.
+                    return try self.constant(ty.unionTagTypeSafety(mod).?, un.tag.toValue(), .indirect);
+                }
+
+                const has_tag = layout.tag_size != 0;
+                const tag_first = layout.tag_align >= layout.payload_align;
+
+                const un_ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function);
+
+                const var_id = self.spv.allocId();
+                try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+                    .id_result_type = self.typeId(un_ptr_ty_ref),
+                    .id_result = var_id,
+                    .storage_class = .Function,
+                });
+
+                const index_ty_ref = try self.intType(.unsigned, 32);
+
+                if (has_tag) {
+                    const tag_index: u32 = if (tag_first) 0 else 1;
+                    const index_id = try self.spv.constInt(index_ty_ref, tag_index);
+                    const tag_ty = ty.unionTagTypeSafety(mod).?;
+                    const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
+                    const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
+                    const ptr_id = try self.accessChain(tag_ptr_ty_ref, var_id, &.{index_id});
+                    const tag_id = try self.constant(tag_ty, un.tag.toValue(), .indirect);
+                    try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                        .pointer = ptr_id,
+                        .object = tag_id,
+                    });
+                }
+
+                const pl_index: u32 = if (tag_first) 1 else 0;
+                const index_id = try self.spv.constInt(index_ty_ref, pl_index);
+                const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
+                const active_field_ty = union_ty.field_types.get(ip)[active_field].toType();
+                const active_field_ty_ref = try self.resolveType(active_field_ty, .indirect);
+                const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
+                const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{index_id});
+                const value_id = try self.constant(active_field_ty, un.val.toValue(), .indirect);
+                try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                    .pointer = ptr_id,
+                    .object = value_id,
+                });
+
+                // Just leave the padding fields uninitialized...
+
+                const result_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+                    .id_result_type = self.typeId(result_ty_ref),
+                    .id_result = result_id,
+                    .pointer = var_id,
+                });
+                return result_id;
+            },
             else => {
                 // The value cannot be generated directly, so generate it as an indirect constant,
                 // and then perform an OpLoad.
@@ -1417,7 +1483,7 @@ pub const DeclGen = struct {
         if (has_tag and tag_first) {
             const tag_ty_ref = try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect);
             member_types.appendAssumeCapacity(tag_ty_ref);
-            member_names.appendAssumeCapacity(try self.spv.resolveString("tag"));
+            member_names.appendAssumeCapacity(try self.spv.resolveString("(tag)"));
         }
 
         const active_field = maybe_active_field orelse layout.most_aligned_field;
@@ -1426,7 +1492,7 @@ pub const DeclGen = struct {
         const active_field_size = if (active_field_ty.hasRuntimeBitsIgnoreComptime(mod)) blk: {
             const active_payload_ty_ref = try self.resolveType(active_field_ty, .indirect);
             member_types.appendAssumeCapacity(active_payload_ty_ref);
-            member_names.appendAssumeCapacity(try self.spv.resolveString("payload"));
+            member_names.appendAssumeCapacity(try self.spv.resolveString("(payload)"));
             break :blk active_field_ty.abiSize(mod);
         } else 0;
 
@@ -1434,19 +1500,19 @@ pub const DeclGen = struct {
         if (payload_padding_len != 0) {
             const payload_padding_ty_ref = try self.spv.arrayType(@as(u32, @intCast(payload_padding_len)), u8_ty_ref);
             member_types.appendAssumeCapacity(payload_padding_ty_ref);
-            member_names.appendAssumeCapacity(try self.spv.resolveString("payload_padding"));
+            member_names.appendAssumeCapacity(try self.spv.resolveString("(payload padding)"));
         }
 
         if (has_tag and !tag_first) {
             const tag_ty_ref = try self.resolveType(union_obj.enum_tag_ty.toType(), .indirect);
             member_types.appendAssumeCapacity(tag_ty_ref);
-            member_names.appendAssumeCapacity(try self.spv.resolveString("tag"));
+            member_names.appendAssumeCapacity(try self.spv.resolveString("(tag)"));
         }
 
         if (layout.padding != 0) {
             const padding_ty_ref = try self.spv.arrayType(layout.padding, u8_ty_ref);
             member_types.appendAssumeCapacity(padding_ty_ref);
-            member_names.appendAssumeCapacity(try self.spv.resolveString("padding"));
+            member_names.appendAssumeCapacity(try self.spv.resolveString("(padding)"));
         }
 
         const ty_ref = try self.spv.resolve(.{ .struct_type = .{