Commit 98046b4c3c

Robin Voetter <robin@voetter.nl>
2023-09-17 18:36:45
spirv: air set_union_tag + improve load()/store()
1 parent 6f55a68
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -1662,13 +1662,11 @@ pub const DeclGen = struct {
         return try self.convertToDirect(result_ty, result_id);
     }
 
-    fn load(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef) !IdRef {
-        const mod = self.module;
-        const value_ty = ptr_ty.childType(mod);
+    fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, is_volatile: bool) !IdRef {
         const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect);
         const result_id = self.spv.allocId();
         const access = spec.MemoryAccess.Extended{
-            .Volatile = ptr_ty.isVolatilePtr(mod),
+            .Volatile = is_volatile,
         };
         try self.func.body.emit(self.spv.gpa, .OpLoad, .{
             .id_result_type = self.typeId(indirect_value_ty_ref),
@@ -1679,12 +1677,10 @@ pub const DeclGen = struct {
         return try self.convertToDirect(value_ty, result_id);
     }
 
-    fn store(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, value_id: IdRef) !void {
-        const mod = self.module;
-        const value_ty = ptr_ty.childType(mod);
+    fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, is_volatile: bool) !void {
         const indirect_value_id = try self.convertToIndirect(value_ty, value_id);
         const access = spec.MemoryAccess.Extended{
-            .Volatile = ptr_ty.isVolatilePtr(mod),
+            .Volatile = is_volatile,
         };
         try self.func.body.emit(self.spv.gpa, .OpStore, .{
             .pointer = ptr_id,
@@ -1754,6 +1750,7 @@ pub const DeclGen = struct {
             .ptr_elem_ptr   => try self.airPtrElemPtr(inst),
             .ptr_elem_val   => try self.airPtrElemVal(inst),
 
+            .set_union_tag => return try self.airSetUnionTag(inst),
             .get_union_tag => try self.airGetUnionTag(inst),
             .struct_field_val => try self.airStructFieldVal(inst),
 
@@ -2512,7 +2509,7 @@ pub const DeclGen = struct {
 
         const slice_ptr = try self.extractField(ptr_ty, slice_id, 0);
         const elem_ptr = try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index_id, &.{});
-        return try self.load(slice_ty, elem_ptr);
+        return try self.load(slice_ty.childType(mod), elem_ptr, slice_ty.isVolatilePtr(mod));
     }
 
     fn ptrElemPtr(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, index_id: IdRef) !IdRef {
@@ -2548,25 +2545,41 @@ pub const DeclGen = struct {
     }
 
     fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
         const mod = self.module;
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
         const ptr_ty = self.typeOf(bin_op.lhs);
+        const elem_ty = self.typeOfIndex(inst);
         const ptr_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
-
         const elem_ptr_id = try self.ptrElemPtr(ptr_ty, ptr_id, index_id);
+        return try self.load(elem_ty, elem_ptr_id, ptr_ty.isVolatilePtr(mod));
+    }
+
+    fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void {
+        const mod = self.module;
+        const bin_op = self.air.instructions.items(.data)[inst].bin_op;
+        const un_ptr_ty = self.typeOf(bin_op.lhs);
+        const un_ty = un_ptr_ty.childType(mod);
+        const layout = self.unionLayout(un_ty, null);
+
+        if (layout.tag_size == 0) return;
+
+        const tag_ty = un_ty.unionTagTypeSafety(mod).?;
+        const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
+        const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, spvStorageClass(un_ptr_ty.ptrAddressSpace(mod)));
 
-        // If we have a pointer-to-array, construct an element pointer to use with load()
-        // If we pass ptr_ty directly, it will attempt to load the entire array rather than
-        // just an element.
-        var elem_ptr_info = ptr_ty.ptrInfo(mod);
-        elem_ptr_info.flags.size = .One;
-        const elem_ptr_ty = try mod.intern_pool.get(mod.gpa, .{ .ptr_type = elem_ptr_info });
+        const union_ptr_id = try self.resolve(bin_op.lhs);
+        const new_tag_id = try self.resolve(bin_op.rhs);
 
-        return try self.load(elem_ptr_ty.toType(), elem_ptr_id);
+        const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index});
+        try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
     }
 
     fn airGetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         const un_ty = self.typeOf(ty_op.operand);
 
@@ -2588,25 +2601,25 @@ pub const DeclGen = struct {
         const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
         const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data;
 
-        const container_ty = self.typeOf(struct_field.struct_operand);
+        const object_ty = self.typeOf(struct_field.struct_operand);
         const object_id = try self.resolve(struct_field.struct_operand);
         const field_index = struct_field.field_index;
-        const field_ty = container_ty.structFieldType(field_index, mod);
+        const field_ty = object_ty.structFieldType(field_index, mod);
 
         if (!field_ty.hasRuntimeBitsIgnoreComptime(mod)) return null;
 
-        switch (container_ty.zigTypeTag(mod)) {
-            .Struct => switch (container_ty.containerLayout(mod)) {
+        switch (object_ty.zigTypeTag(mod)) {
+            .Struct => switch (object_ty.containerLayout(mod)) {
                 .Packed => unreachable, // TODO
                 else => return try self.extractField(field_ty, object_id, field_index),
             },
-            .Union => switch (container_ty.containerLayout(mod)) {
+            .Union => switch (object_ty.containerLayout(mod)) {
                 .Packed => unreachable, // TODO
                 else => {
                     // Store, pointer-cast, load
-                    const un_general_ty_ref = try self.resolveType(container_ty, .indirect);
+                    const un_general_ty_ref = try self.resolveType(object_ty, .indirect);
                     const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function);
-                    const un_active_ty_ref = try self.resolveUnionType(container_ty, field_index);
+                    const un_active_ty_ref = try self.resolveUnionType(object_ty, field_index);
                     const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
                     const field_ty_ref = try self.resolveType(field_ty, .indirect);
                     const field_ptr_ty_ref = try self.spv.ptrType(field_ty_ref, .Function);
@@ -2617,31 +2630,20 @@ pub const DeclGen = struct {
                         .id_result = tmp_id,
                         .storage_class = .Function,
                     });
-                    try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                        .pointer = tmp_id,
-                        .object = object_id,
-                    });
+                    try self.store(object_ty, tmp_id, object_id, false);
                     const casted_tmp_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
                         .id_result_type = self.typeId(un_active_ptr_ty_ref),
                         .id_result = casted_tmp_id,
                         .operand = tmp_id,
                     });
-                    const layout = self.unionLayout(container_ty, field_index);
+                    const layout = self.unionLayout(object_ty, field_index);
                     const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index});
-                    const result_id = self.spv.allocId();
-                    try self.func.body.emit(self.spv.gpa, .OpLoad, .{
-                        .id_result_type = self.typeId(field_ty_ref),
-                        .id_result = result_id,
-                        .pointer = field_ptr_id,
-                    });
-                    return try self.convertToDirect(field_ty, result_id);
+                    return try self.load(field_ty, field_ptr_id, false);
                 },
             },
             else => unreachable,
         }
-
-        // return try self.extractField(field_ty, object_id, field_index);
     }
 
     fn structFieldPtr(
@@ -2866,19 +2868,21 @@ pub const DeclGen = struct {
         const mod = self.module;
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         const ptr_ty = self.typeOf(ty_op.operand);
+        const elem_ty = self.typeOfIndex(inst);
         const operand = try self.resolve(ty_op.operand);
         if (!ptr_ty.isVolatilePtr(mod) and self.liveness.isUnused(inst)) return null;
 
-        return try self.load(ptr_ty, operand);
+        return try self.load(elem_ty, operand, ptr_ty.isVolatilePtr(mod));
     }
 
     fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void {
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
         const ptr_ty = self.typeOf(bin_op.lhs);
+        const elem_ty = ptr_ty.childType(self.module);
         const ptr = try self.resolve(bin_op.lhs);
         const value = try self.resolve(bin_op.rhs);
 
-        try self.store(ptr_ty, ptr, value);
+        try self.store(elem_ty, ptr, value, ptr_ty.isVolatilePtr(self.module));
     }
 
     fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -2922,7 +2926,7 @@ pub const DeclGen = struct {
         }
 
         const ptr = try self.resolve(un_op);
-        const value = try self.load(ptr_ty, ptr);
+        const value = try self.load(ret_ty, ptr, ptr_ty.isVolatilePtr(mod));
         try self.func.body.emit(self.spv.gpa, .OpReturnValue, .{
             .value = value,
         });
test/behavior/union.zig
@@ -29,7 +29,6 @@ test "init union with runtime value - floats" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var foo: FooWithFloats = undefined;
 
@@ -59,7 +58,6 @@ test "init union with runtime value" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var foo: Foo = undefined;
 
@@ -170,7 +168,6 @@ test "constant tagged union with payload" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var empty = TaggedUnionWithPayload{ .Empty = {} };
     var full = TaggedUnionWithPayload{ .Full = 13 };
@@ -508,7 +505,6 @@ test "union initializer generates padding only if needed" {
 test "runtime tag name with single field" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union(enum) {
         A: i32,
@@ -585,7 +581,6 @@ test "tagged union as return value" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     switch (returnAnInt(13)) {
         TaggedFoo.One => |value| try expect(value == 13),
@@ -630,7 +625,6 @@ test "union(enum(u32)) with specified and unspecified tag values" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime expect(Tag(Tag(MultipleChoice2)) == u32);
     try testEnumWithSpecifiedAndUnspecifiedTagValues(MultipleChoice2{ .C = 123 });
@@ -668,7 +662,6 @@ test "switch on union with only 1 field" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var r: PartialInst = undefined;
     r = PartialInst.Compiled;
@@ -697,7 +690,6 @@ const PartialInstWithPayload = union(enum) {
 
 test "union with only 1 field casted to its enum type which has enum value specified" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Literal = union(enum) {
         Number: f64,
@@ -782,7 +774,6 @@ test "return union init with void payload" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn entry() !void {
@@ -836,7 +827,6 @@ test "@unionInit can modify a union type" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const UnionInitEnum = union(enum) {
         Boolean: bool,
@@ -860,7 +850,6 @@ test "@unionInit can modify a pointer value" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const UnionInitEnum = union(enum) {
         Boolean: bool,
@@ -917,7 +906,6 @@ test "anonymous union literal syntax" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const Number = union {
@@ -1041,7 +1029,6 @@ test "switching on non exhaustive union" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const E = enum(u8) {
@@ -1225,7 +1212,6 @@ test "union tag is set when initiated as a temporary value at runtime" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union(enum) {
         a,
@@ -1263,7 +1249,6 @@ test "return an extern union from C calling convention" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const namespace = struct {
         const S = extern struct {
@@ -1294,7 +1279,6 @@ test "noreturn field in union" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union(enum) {
         a: u32,
@@ -1475,7 +1459,6 @@ test "no dependency loop when function pointer in union returns the union" {
 test "union reassignment can use previous value" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union {
         a: u32,
@@ -1527,7 +1510,6 @@ test "reinterpreting enum value inside packed union" {
 
 test "access the tag of a global tagged union" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union(enum) {
         a,
@@ -1539,7 +1521,6 @@ test "access the tag of a global tagged union" {
 
 test "coerce enum literal to union in result loc" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const U = union(enum) {
         a,