Commit 749307dbb2

Robin Voetter <robin@voetter.nl>
2023-09-17 20:25:32
spirv: air union_init
1 parent 98046b4
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -801,77 +801,14 @@ pub const DeclGen = struct {
                 else => unreachable,
             },
             .un => |un| {
-                // To initialize a union, generate a temporary variable with the
-                // type that has the right field active, then pointer-cast and store
-                // the active field, and finally load and return the entire union.
-
-                const union_ty = mod.typeToUnion(ty).?;
-
-                if (union_ty.getLayout(ip) == .Packed) {
-                    return self.todo("packed union types", .{});
-                }
-
                 const active_field = ty.unionTagFieldIndex(un.tag.toValue(), mod).?;
                 const layout = self.unionLayout(ty, active_field);
+                const payload = if (layout.active_field_size != 0)
+                    try self.constant(layout.active_field_ty, un.val.toValue(), .indirect)
+                else
+                    null;
 
-                if (layout.payload_size == 0) {
-                    // No payload, so represent this as just the tag type.
-                    return try self.constant(ty.unionTagTypeSafety(mod).?, un.tag.toValue(), .indirect);
-                }
-
-                const un_active_ty_ref = try self.resolveUnionType(ty, active_field);
-                const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
-                const un_general_ptr_ty_ref = try self.spv.ptrType(result_ty_ref, .Function);
-
-                const var_id = self.spv.allocId();
-                try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
-                    .id_result_type = self.typeId(un_active_ptr_ty_ref),
-                    .id_result = var_id,
-                    .storage_class = .Function,
-                });
-
-                if (layout.tag_size != 0) {
-                    const tag_ty = ty.unionTagTypeSafety(mod).?;
-                    const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
-                    const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
-                    const ptr_id = try self.accessChain(tag_ptr_ty_ref, var_id, &.{@as(u32, @intCast(layout.tag_index))});
-                    const tag_id = try self.constant(tag_ty, un.tag.toValue(), .indirect);
-                    try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                        .pointer = ptr_id,
-                        .object = tag_id,
-                    });
-                }
-
-                if (layout.active_field_size != 0) {
-                    const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
-                    const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
-                    const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{@as(u32, @intCast(layout.active_field_index))});
-                    const value_id = try self.constant(layout.active_field_ty, un.val.toValue(), .indirect);
-                    try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                        .pointer = ptr_id,
-                        .object = value_id,
-                    });
-                }
-
-                // Just leave the padding fields uninitialized...
-                // TODO: Or should we initialize them with undef explicitly?
-
-                // Now cast the pointer and load it as the 'generic' union type.
-
-                const casted_var_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
-                    .id_result_type = self.typeId(un_general_ptr_ty_ref),
-                    .id_result = casted_var_id,
-                    .operand = var_id,
-                });
-
-                const result_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpLoad, .{
-                    .id_result_type = self.typeId(result_ty_ref),
-                    .id_result = result_id,
-                    .pointer = casted_var_id,
-                });
-                return result_id;
+                return try self.unionInit(ty, active_field, payload);
             },
             .memoized_call => unreachable,
         }
@@ -1752,6 +1689,8 @@ pub const DeclGen = struct {
 
             .set_union_tag => return try self.airSetUnionTag(inst),
             .get_union_tag => try self.airGetUnionTag(inst),
+            .union_init => try self.airUnionInit(inst),
+
             .struct_field_val => try self.airStructFieldVal(inst),
 
             .struct_field_ptr_index_0 => try self.airStructFieldPtrIndex(inst, 0),
@@ -2573,8 +2512,12 @@ pub const DeclGen = struct {
         const union_ptr_id = try self.resolve(bin_op.lhs);
         const new_tag_id = try self.resolve(bin_op.rhs);
 
-        const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index});
-        try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+        if (layout.payload_size == 0) {
+            try self.store(tag_ty, union_ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+        } else {
+            const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index});
+            try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+        }
     }
 
     fn airGetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2594,6 +2537,113 @@ pub const DeclGen = struct {
         return try self.extractField(tag_ty, union_handle, layout.tag_index);
     }
 
+    fn unionInit(
+        self: *DeclGen,
+        ty: Type,
+        active_field: u32,
+        payload: ?IdRef,
+    ) !IdRef {
+        // To initialize a union, generate a temporary variable with the
+        // type that has the right field active, then pointer-cast and store
+        // the active field, and finally load and return the entire union.
+
+        const mod = self.module;
+        const ip = &mod.intern_pool;
+        const union_ty = mod.typeToUnion(ty).?;
+
+        if (union_ty.getLayout(ip) == .Packed) {
+            unreachable; // TODO
+        }
+
+        const maybe_tag_ty = ty.unionTagTypeSafety(mod);
+        const layout = self.unionLayout(ty, active_field);
+
+        const tag_int = if (layout.tag_size != 0) blk: {
+            const tag_ty = maybe_tag_ty.?;
+            const union_field_name = union_ty.field_names.get(ip)[active_field];
+            const enum_field_index = tag_ty.enumFieldIndex(union_field_name, mod).?;
+            const tag_val = try mod.enumValueFieldIndex(tag_ty, enum_field_index);
+            const tag_int_val = try tag_val.intFromEnum(tag_ty, mod);
+            break :blk tag_int_val.toUnsignedInt(mod);
+        } else 0;
+
+        if (layout.payload_size == 0) {
+            const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct);
+            return try self.constInt(tag_ty_ref, tag_int);
+        }
+
+        const un_active_ty_ref = try self.resolveUnionType(ty, active_field);
+        const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
+        const un_general_ty_ref = try self.resolveType(ty, .direct);
+        const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function);
+
+        const tmp_id = self.spv.allocId();
+        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+            .id_result_type = self.typeId(un_active_ptr_ty_ref),
+            .id_result = tmp_id,
+            .storage_class = .Function,
+        });
+
+        if (layout.tag_size != 0) {
+            const tag_ty_ref = try self.resolveType(maybe_tag_ty.?, .direct);
+            const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
+            const ptr_id = try self.accessChain(tag_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.tag_index))});
+            const tag_id = try self.constInt(tag_ty_ref, tag_int);
+            try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                .pointer = ptr_id,
+                .object = tag_id,
+            });
+        }
+
+        if (layout.active_field_size != 0) {
+            const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
+            const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
+            const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))});
+            try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                .pointer = ptr_id,
+                .object = payload.?,
+            });
+        } else {
+            assert(payload == null);
+        }
+
+        // Just leave the padding fields uninitialized...
+        // TODO: Or should we initialize them with undef explicitly?
+
+        // Now cast the pointer and load it as the 'generic' union type.
+
+        const casted_var_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+            .id_result_type = self.typeId(un_general_ptr_ty_ref),
+            .id_result = casted_var_id,
+            .operand = tmp_id,
+        });
+
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+            .id_result_type = self.typeId(un_general_ty_ref),
+            .id_result = result_id,
+            .pointer = casted_var_id,
+        });
+
+        return result_id;
+    }
+
+    fn airUnionInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+        const extra = self.air.extraData(Air.UnionInit, ty_pl.payload).data;
+        const ty = self.typeOfIndex(inst);
+        const layout = self.unionLayout(ty, extra.field_index);
+
+        const payload = if (layout.active_field_size != 0)
+            try self.resolve(extra.init)
+        else
+            null;
+        return try self.unionInit(ty, extra.field_index, payload);
+    }
+
     fn airStructFieldVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
test/behavior/union.zig
@@ -997,7 +997,6 @@ test "cast from pointer to anonymous struct to pointer to union" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const U = union(enum) {