Commit 5dffbf32bf

Robin Voetter <robin@voetter.nl>
2023-09-17 17:17:46
spirv: air struct_field_val for unions
1 parent decdedf
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -452,16 +452,12 @@ pub const DeclGen = struct {
             .storage_class = .Function,
         });
 
-        // Note: using 32-bit ints here because usize crashes the translator as well
-        const index_ty_ref = try self.intType(.unsigned, 32);
-
         const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).struct_type;
         const member_types = spv_composite_ty.member_types;
 
         for (constituents, member_types, 0..) |constitent_id, member_ty_ref, index| {
-            const index_id = try self.constInt(index_ty_ref, index);
             const ptr_member_ty_ref = try self.spv.ptrType(member_ty_ref, .Function);
-            const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{index_id});
+            const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
                 .pointer = ptr_id,
                 .object = constitent_id,
@@ -493,16 +489,12 @@ pub const DeclGen = struct {
             .storage_class = .Function,
         });
 
-        // Note: using 32-bit ints here because usize crashes the translator as well
-        const index_ty_ref = try self.intType(.unsigned, 32);
-
         const spv_composite_ty = self.spv.cache.lookup(result_ty_ref).array_type;
         const elem_ty_ref = spv_composite_ty.element_type;
         const ptr_elem_ty_ref = try self.spv.ptrType(elem_ty_ref, .Function);
 
         for (constituents, 0..) |constitent_id, index| {
-            const index_id = try self.constInt(index_ty_ref, index);
-            const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{index_id});
+            const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
                 .pointer = ptr_id,
                 .object = constitent_id,
@@ -535,18 +527,36 @@ pub const DeclGen = struct {
                 const decl_id = self.spv.declPtr(spv_decl_index).result_id;
                 try self.func.decl_deps.put(self.spv.gpa, spv_decl_index, {});
 
-                switch (decl.@"addrspace") {
-                    .generic => {
-                        // Pointer should be generic, but is actually placed in CrossWorkgroup.
+                const final_storage_class = spvStorageClass(decl.@"addrspace");
+
+                const decl_ty_ref = try self.resolveType(decl.ty, .indirect);
+                const decl_ptr_ty_ref = try self.spv.ptrType(decl_ty_ref, final_storage_class);
+
+                const ptr_id = switch (final_storage_class) {
+                    .Generic => blk: {
+                        // Pointer should be Generic, but is actually placed in CrossWorkgroup.
                         const result_id = self.spv.allocId();
                         try self.func.body.emit(self.spv.gpa, .OpPtrCastToGeneric, .{
-                            .id_result_type = ty_id,
+                            .id_result_type = self.typeId(decl_ptr_ty_ref),
                             .id_result = result_id,
                             .pointer = decl_id,
                         });
-                        return result_id;
+                        break :blk result_id;
                     },
-                    else => return decl_id, // Variable is already correct, probably. Maybe needs a bitcast?
+                    else => decl_id,
+                };
+
+                if (decl_ptr_ty_ref != ty_ref) {
+                    // Differing pointer types, insert a cast.
+                    const casted_ptr_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+                        .id_result_type = ty_id,
+                        .id_result = casted_ptr_id,
+                        .operand = ptr_id,
+                    });
+                    return casted_ptr_id;
+                } else {
+                    return ptr_id;
                 }
             },
         }
@@ -820,14 +830,11 @@ pub const DeclGen = struct {
                     .storage_class = .Function,
                 });
 
-                const index_ty_ref = try self.intType(.unsigned, 32);
-
                 if (layout.tag_size != 0) {
-                    const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.tag_index)));
                     const tag_ty = ty.unionTagTypeSafety(mod).?;
                     const tag_ty_ref = try self.resolveType(tag_ty, .indirect);
                     const tag_ptr_ty_ref = try self.spv.ptrType(tag_ty_ref, .Function);
-                    const ptr_id = try self.accessChain(tag_ptr_ty_ref, var_id, &.{index_id});
+                    const ptr_id = try self.accessChain(tag_ptr_ty_ref, var_id, &.{@as(u32, @intCast(layout.tag_index))});
                     const tag_id = try self.constant(tag_ty, un.tag.toValue(), .indirect);
                     try self.func.body.emit(self.spv.gpa, .OpStore, .{
                         .pointer = ptr_id,
@@ -836,10 +843,9 @@ pub const DeclGen = struct {
                 }
 
                 if (layout.active_field_size != 0) {
-                    const index_id = try self.constInt(index_ty_ref, @as(u32, @intCast(layout.active_field_index)));
                     const active_field_ty_ref = try self.resolveType(layout.active_field_ty, .indirect);
                     const active_field_ptr_ty_ref = try self.spv.ptrType(active_field_ty_ref, .Function);
-                    const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{index_id});
+                    const ptr_id = try self.accessChain(active_field_ptr_ty_ref, var_id, &.{@as(u32, @intCast(layout.active_field_index))});
                     const value_id = try self.constant(layout.active_field_ty, un.val.toValue(), .indirect);
                     try self.func.body.emit(self.spv.gpa, .OpStore, .{
                         .pointer = ptr_id,
@@ -2070,40 +2076,65 @@ pub const DeclGen = struct {
         return result_id;
     }
 
-    /// AccessChain is essentially PtrAccessChain with 0 as initial argument. The effective
-    /// difference lies in whether the resulting type of the first dereference will be the
-    /// same as that of the base pointer, or that of a dereferenced base pointer. AccessChain
-    /// is the latter and PtrAccessChain is the former.
-    fn accessChain(
+    fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
+        const index_ty_ref = try self.intType(.unsigned, 32);
+        const ids = try self.gpa.alloc(IdRef, indices.len);
+        errdefer self.gpa.free(ids);
+        for (indices, ids) |index, *id| {
+            id.* = try self.constInt(index_ty_ref, index);
+        }
+
+        return ids;
+    }
+
+    fn accessChainId(
         self: *DeclGen,
         result_ty_ref: CacheRef,
         base: IdRef,
-        indexes: []const IdRef,
+        indices: []const IdRef,
     ) !IdRef {
         const result_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
             .id_result_type = self.typeId(result_ty_ref),
             .id_result = result_id,
             .base = base,
-            .indexes = indexes,
+            .indexes = indices,
         });
         return result_id;
     }
 
+    /// AccessChain is essentially PtrAccessChain with 0 as initial argument. The effective
+    /// difference lies in whether the resulting type of the first dereference will be the
+    /// same as that of the base pointer, or that of a dereferenced base pointer. AccessChain
+    /// is the latter and PtrAccessChain is the former.
+    fn accessChain(
+        self: *DeclGen,
+        result_ty_ref: CacheRef,
+        base: IdRef,
+        indices: []const u32,
+    ) !IdRef {
+        const ids = try self.indicesToIds(indices);
+        defer self.gpa.free(ids);
+        return try self.accessChainId(result_ty_ref, base, ids);
+    }
+
     fn ptrAccessChain(
         self: *DeclGen,
         result_ty_ref: CacheRef,
         base: IdRef,
         element: IdRef,
-        indexes: []const IdRef,
+        indices: []const u32,
     ) !IdRef {
+        const ids = try self.indicesToIds(indices);
+        defer self.gpa.free(ids);
+
         const result_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpInBoundsPtrAccessChain, .{
             .id_result_type = self.typeId(result_ty_ref),
             .id_result = result_id,
             .base = base,
             .element = element,
-            .indexes = indexes,
+            .indexes = ids,
         });
         return result_id;
     }
@@ -2116,7 +2147,7 @@ pub const DeclGen = struct {
             .One => {
                 // Pointer to array
                 // TODO: Is this correct?
-                return try self.accessChain(result_ty_ref, ptr_id, &.{offset_id});
+                return try self.accessChainId(result_ty_ref, ptr_id, &.{offset_id});
             },
             .C, .Many => {
                 return try self.ptrAccessChain(result_ty_ref, ptr_id, offset_id, &.{});
@@ -2493,7 +2524,7 @@ pub const DeclGen = struct {
         if (ptr_ty.isSinglePointer(mod)) {
             // Pointer-to-array. In this case, the resulting pointer is not of the same type
             // as the ptr_ty (we want a *T, not a *[N]T), and hence we need to use accessChain.
-            return try self.accessChain(elem_ptr_ty_ref, ptr_id, &.{index_id});
+            return try self.accessChainId(elem_ptr_ty_ref, ptr_id, &.{index_id});
         } else {
             // Resulting pointer type is the same as the ptr_ty, so use ptrAccessChain
             return try self.ptrAccessChain(elem_ptr_ty_ref, ptr_id, index_id, &.{});
@@ -2540,15 +2571,14 @@ pub const DeclGen = struct {
         const un_ty = self.typeOf(ty_op.operand);
 
         const mod = self.module;
-        const layout = un_ty.unionGetLayout(mod);
+        const layout = self.unionLayout(un_ty, null);
         if (layout.tag_size == 0) return null;
 
         const union_handle = try self.resolve(ty_op.operand);
         if (layout.payload_size == 0) return union_handle;
 
         const tag_ty = un_ty.unionTagTypeSafety(mod).?;
-        const tag_index = @intFromBool(layout.tag_align.compare(.lt, layout.payload_align));
-        return try self.extractField(tag_ty, union_handle, tag_index);
+        return try self.extractField(tag_ty, union_handle, layout.tag_index);
     }
 
     fn airStructFieldVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2558,16 +2588,60 @@ pub const DeclGen = struct {
         const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
         const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data;
 
-        const struct_ty = self.typeOf(struct_field.struct_operand);
+        const container_ty = self.typeOf(struct_field.struct_operand);
         const object_id = try self.resolve(struct_field.struct_operand);
         const field_index = struct_field.field_index;
-        const field_ty = struct_ty.structFieldType(field_index, mod);
+        const field_ty = container_ty.structFieldType(field_index, mod);
 
         if (!field_ty.hasRuntimeBitsIgnoreComptime(mod)) return null;
 
-        assert(struct_ty.zigTypeTag(mod) == .Struct); // Cannot do unions yet.
+        switch (container_ty.zigTypeTag(mod)) {
+            .Struct => switch (container_ty.containerLayout(mod)) {
+                .Packed => unreachable, // TODO
+                else => return try self.extractField(field_ty, object_id, field_index),
+            },
+            .Union => switch (container_ty.containerLayout(mod)) {
+                .Packed => unreachable, // TODO
+                else => {
+                    // Store, pointer-cast, load
+                    const un_general_ty_ref = try self.resolveType(container_ty, .indirect);
+                    const un_general_ptr_ty_ref = try self.spv.ptrType(un_general_ty_ref, .Function);
+                    const un_active_ty_ref = try self.resolveUnionType(container_ty, field_index);
+                    const un_active_ptr_ty_ref = try self.spv.ptrType(un_active_ty_ref, .Function);
+                    const field_ty_ref = try self.resolveType(field_ty, .indirect);
+                    const field_ptr_ty_ref = try self.spv.ptrType(field_ty_ref, .Function);
+
+                    const tmp_id = self.spv.allocId();
+                    try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+                        .id_result_type = self.typeId(un_general_ptr_ty_ref),
+                        .id_result = tmp_id,
+                        .storage_class = .Function,
+                    });
+                    try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                        .pointer = tmp_id,
+                        .object = object_id,
+                    });
+                    const casted_tmp_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
+                        .id_result_type = self.typeId(un_active_ptr_ty_ref),
+                        .id_result = casted_tmp_id,
+                        .operand = tmp_id,
+                    });
+                    const layout = self.unionLayout(container_ty, field_index);
+                    const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index});
+                    const result_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+                        .id_result_type = self.typeId(field_ty_ref),
+                        .id_result = result_id,
+                        .pointer = field_ptr_id,
+                    });
+                    return try self.convertToDirect(field_ty, result_id);
+                },
+            },
+            else => unreachable,
+        }
 
-        return try self.extractField(field_ty, object_id, field_index);
+        // return try self.extractField(field_ty, object_id, field_index);
     }
 
     fn structFieldPtr(
@@ -2583,10 +2657,8 @@ pub const DeclGen = struct {
             .Struct => switch (object_ty.containerLayout(mod)) {
                 .Packed => unreachable, // TODO
                 else => {
-                    const field_index_ty_ref = try self.intType(.unsigned, 32);
-                    const field_index_id = try self.constInt(field_index_ty_ref, field_index);
                     const result_ty_ref = try self.resolveType(result_ptr_ty, .direct);
-                    return try self.accessChain(result_ty_ref, object_ptr, &.{field_index_id});
+                    return try self.accessChain(result_ty_ref, object_ptr, &.{field_index});
                 },
             },
             else => unreachable, // TODO
test/behavior/union.zig
@@ -14,7 +14,6 @@ test "basic unions with floats" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var foo = FooWithFloats{ .int = 1 };
     try expect(foo.int == 1);
@@ -42,7 +41,6 @@ test "basic unions" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var foo = Foo{ .int = 1 };
     try expect(foo.int == 1);
@@ -342,7 +340,6 @@ test "constant packed union" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try testConstPackedUnion(&[_]PackThis{PackThis{ .StringLiteral = 1 }});
 }
@@ -453,7 +450,6 @@ test "global union with single field is correctly initialized" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     glbl = Foo1{
         .f = @typeInfo(Foo1).Union.fields[0].type{ .x = 123 },