Commit 6c05557072

Robin Voetter <robin@voetter.nl>
2023-05-18 18:55:15
spirv: fix some (Ptr)AccessChain uses
The first dereference of PtrAccessChain returns a pointer of the same type as the base pointer, in contrast to AccessChain, where the first dereference returns a pointer of the dereferenced type of the base pointer.
1 parent 0ba0d8f
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -457,14 +457,8 @@ pub const DeclGen = struct {
         const members = spv_composite_ty.payload(.@"struct").members;
         for (constituents, members, 0..) |constitent_id, member, index| {
             const index_id = try self.constInt(index_ty_ref, index);
-            const ptr_id = self.spv.allocId();
             const ptr_member_ty_ref = try self.spv.ptrType(member.ty, .Generic, 0);
-            try self.func.body.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
-                .id_result_type = self.typeId(ptr_member_ty_ref),
-                .id_result = ptr_id,
-                .base = ptr_composite_id,
-                .indexes = &.{index_id},
-            });
+            const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{index_id});
             try self.func.body.emit(self.spv.gpa, .OpStore, .{
                 .pointer = ptr_id,
                 .object = constitent_id,
@@ -2089,6 +2083,44 @@ pub const DeclGen = struct {
         return result_id;
     }
 
+    /// AccessChain is essentially PtrAccessChain with 0 as initial argument. The effective
+    /// difference lies in whether the resulting type of the first dereference will be the
+    /// same as that of the base pointer, or that of a dereferenced base pointer. AccessChain
+    /// is the latter and PtrAccessChain is the former.
+    fn accessChain(
+        self: *DeclGen,
+        result_ty_ref: SpvType.Ref,
+        base: IdRef,
+        indexes: []const IdRef,
+    ) !IdRef {
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
+            .id_result_type = self.typeId(result_ty_ref),
+            .id_result = result_id,
+            .base = base,
+            .indexes = indexes,
+        });
+        return result_id;
+    }
+
+    fn ptrAccessChain(
+        self: *DeclGen,
+        result_ty_ref: SpvType.Ref,
+        base: IdRef,
+        element: IdRef,
+        indexes: []const IdRef,
+    ) !IdRef {
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpInBoundsPtrAccessChain, .{
+            .id_result_type = self.typeId(result_ty_ref),
+            .id_result = result_id,
+            .base = base,
+            .element = element,
+            .indexes = indexes,
+        });
+        return result_id;
+    }
+
     fn cmp(
         self: *DeclGen,
         comptime op: std.math.CompareOperator,
@@ -2353,12 +2385,12 @@ pub const DeclGen = struct {
         const slice = try self.resolve(bin_op.lhs);
         const index = try self.resolve(bin_op.rhs);
 
-        const spv_ptr_ty = try self.resolveTypeId(self.air.typeOfIndex(inst));
+        const ptr_ty_ref = try self.resolveType(self.air.typeOfIndex(inst), .direct);
 
         const slice_ptr = blk: {
             const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
-                .id_result_type = spv_ptr_ty,
+                .id_result_type = self.typeId(ptr_ty_ref),
                 .id_result = result_id,
                 .composite = slice,
                 .indexes = &.{0},
@@ -2366,14 +2398,7 @@ pub const DeclGen = struct {
             break :blk result_id;
         };
 
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpInBoundsPtrAccessChain, .{
-            .id_result_type = spv_ptr_ty,
-            .id_result = result_id,
-            .base = slice_ptr,
-            .element = index,
-        });
-        return result_id;
+        return try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index, &.{});
     }
 
     fn airSliceElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2385,12 +2410,12 @@ pub const DeclGen = struct {
         const index = try self.resolve(bin_op.rhs);
 
         var slice_buf: Type.SlicePtrFieldTypeBuffer = undefined;
-        const ptr_ty_id = try self.resolveTypeId(slice_ty.slicePtrFieldType(&slice_buf));
+        const ptr_ty_ref = try self.resolveType(slice_ty.slicePtrFieldType(&slice_buf), .direct);
 
         const slice_ptr = blk: {
             const result_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
-                .id_result_type = ptr_ty_id,
+                .id_result_type = self.typeId(ptr_ty_ref),
                 .id_result = result_id,
                 .composite = slice,
                 .indexes = &.{0},
@@ -2398,17 +2423,7 @@ pub const DeclGen = struct {
             break :blk result_id;
         };
 
-        const elem_ptr = blk: {
-            const result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpInBoundsPtrAccessChain, .{
-                .id_result_type = ptr_ty_id,
-                .id_result = result_id,
-                .base = slice_ptr,
-                .element = index,
-            });
-            break :blk result_id;
-        };
-
+        const elem_ptr = try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index, &.{});
         return try self.load(slice_ty, elem_ptr);
     }
 
@@ -2423,19 +2438,18 @@ pub const DeclGen = struct {
         // TODO: Make this return a null ptr or something
         if (!elem_ty.hasRuntimeBitsIgnoreComptime()) return null;
 
-        const result_type_id = try self.resolveTypeId(result_ty);
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
         const base_ptr = try self.resolve(bin_op.lhs);
         const rhs = try self.resolve(bin_op.rhs);
 
-        const result_id = self.spv.allocId();
-        const indexes = [_]IdRef{rhs};
-        try self.func.body.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
-            .id_result_type = result_type_id,
-            .id_result = result_id,
-            .base = base_ptr,
-            .indexes = &indexes,
-        });
-        return result_id;
+        if (ptr_ty.isSinglePointer()) {
+            // Pointer-to-array. In this case, the resulting pointer is not of the same type
+            // as the ptr_ty, and hence we need to use accessChain.
+            return try self.accessChain(result_ty_ref, base_ptr, &.{rhs});
+        } else {
+            // Resulting pointer type is the same as the ptr_ty, so use ptrAccessChain
+            return try self.ptrAccessChain(result_ty_ref, base_ptr, rhs, &.{});
+        }
     }
 
     fn airStructFieldVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -2480,16 +2494,8 @@ pub const DeclGen = struct {
                     const u32_ty_id = self.typeId(try self.intType(.unsigned, 32));
                     const field_index_id = self.spv.allocId();
                     try self.spv.emitConstant(u32_ty_id, field_index_id, .{ .uint32 = field_index });
-                    const result_id = self.spv.allocId();
-                    const result_type_id = try self.resolveTypeId(result_ptr_ty);
-                    const indexes = [_]IdRef{field_index_id};
-                    try self.func.body.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
-                        .id_result_type = result_type_id,
-                        .id_result = result_id,
-                        .base = object_ptr,
-                        .indexes = &indexes,
-                    });
-                    return result_id;
+                    const result_ty_ref = try self.resolveType(result_ptr_ty, .direct);
+                    return try self.accessChain(result_ty_ref, object_ptr, &.{field_index_id});
                 },
             },
             else => unreachable, // TODO