Commit 5090d75e48

Robin Voetter <robin@voetter.nl>
2023-10-21 13:04:18
spirv: make load() and store() accept MemoryOptions
This struct is used to configure the load, such as to make it volatile. Previously this was done using a single bool, but this struct makes it shorter to write non-volatile loads (the usual) and more clear whats going on when a volatile load is required.
1 parent 200bca3
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -1910,11 +1910,15 @@ const DeclGen = struct {
         return try self.convertToDirect(result_ty, result_id);
     }
 
-    fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, is_volatile: bool) !IdRef {
+    const MemoryOptions = struct {
+        is_volatile: bool = false,
+    };
+
+    fn load(self: *DeclGen, value_ty: Type, ptr_id: IdRef, options: MemoryOptions) !IdRef {
         const indirect_value_ty_ref = try self.resolveType(value_ty, .indirect);
         const result_id = self.spv.allocId();
         const access = spec.MemoryAccess.Extended{
-            .Volatile = is_volatile,
+            .Volatile = options.is_volatile,
         };
         try self.func.body.emit(self.spv.gpa, .OpLoad, .{
             .id_result_type = self.typeId(indirect_value_ty_ref),
@@ -1925,10 +1929,10 @@ const DeclGen = struct {
         return try self.convertToDirect(value_ty, result_id);
     }
 
-    fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, is_volatile: bool) !void {
+    fn store(self: *DeclGen, value_ty: Type, ptr_id: IdRef, value_id: IdRef, options: MemoryOptions) !void {
         const indirect_value_id = try self.convertToIndirect(value_ty, value_id);
         const access = spec.MemoryAccess.Extended{
-            .Volatile = is_volatile,
+            .Volatile = options.is_volatile,
         };
         try self.func.body.emit(self.spv.gpa, .OpStore, .{
             .pointer = ptr_id,
@@ -2849,14 +2853,14 @@ const DeclGen = struct {
         const dst_ptr_ty_ref = try self.ptrType(dst_ty, .Function);
 
         const tmp_id = try self.alloc(src_ty, .{ .storage_class = .Function });
-        try self.store(src_ty, tmp_id, src_id, false);
+        try self.store(src_ty, tmp_id, src_id, .{});
         const casted_ptr_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
             .id_result_type = self.typeId(dst_ptr_ty_ref),
             .id_result = casted_ptr_id,
             .operand = tmp_id,
         });
-        return try self.load(dst_ty, casted_ptr_id, false);
+        return try self.load(dst_ty, casted_ptr_id, .{});
     }
 
     fn airBitCast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3219,7 +3223,7 @@ const DeclGen = struct {
 
         const slice_ptr = try self.extractField(ptr_ty, slice_id, 0);
         const elem_ptr = try self.ptrAccessChain(ptr_ty_ref, slice_ptr, index_id, &.{});
-        return try self.load(slice_ty.childType(mod), elem_ptr, slice_ty.isVolatilePtr(mod));
+        return try self.load(slice_ty.childType(mod), elem_ptr, .{ .is_volatile = slice_ty.isVolatilePtr(mod) });
     }
 
     fn ptrElemPtr(self: *DeclGen, ptr_ty: Type, ptr_id: IdRef, index_id: IdRef) !IdRef {
@@ -3273,9 +3277,9 @@ const DeclGen = struct {
         const elem_ptr_ty_ref = try self.ptrType(elem_ty, .Function);
 
         const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function });
-        try self.store(array_ty, tmp_id, array_id, false);
+        try self.store(array_ty, tmp_id, array_id, .{});
         const elem_ptr_id = try self.accessChainId(elem_ptr_ty_ref, tmp_id, &.{index_id});
-        return try self.load(elem_ty, elem_ptr_id, false);
+        return try self.load(elem_ty, elem_ptr_id, .{});
     }
 
     fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3288,7 +3292,7 @@ const DeclGen = struct {
         const ptr_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
         const elem_ptr_id = try self.ptrElemPtr(ptr_ty, ptr_id, index_id);
-        return try self.load(elem_ty, elem_ptr_id, ptr_ty.isVolatilePtr(mod));
+        return try self.load(elem_ty, elem_ptr_id, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
     }
 
     fn airSetUnionTag(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3307,10 +3311,10 @@ const DeclGen = struct {
         const new_tag_id = try self.resolve(bin_op.rhs);
 
         if (layout.payload_size == 0) {
-            try self.store(tag_ty, union_ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+            try self.store(tag_ty, union_ptr_id, new_tag_id, .{ .is_volatile = un_ptr_ty.isVolatilePtr(mod) });
         } else {
             const ptr_id = try self.accessChain(tag_ptr_ty_ref, union_ptr_id, &.{layout.tag_index});
-            try self.store(tag_ty, ptr_id, new_tag_id, un_ptr_ty.isVolatilePtr(mod));
+            try self.store(tag_ty, ptr_id, new_tag_id, .{ .is_volatile = un_ptr_ty.isVolatilePtr(mod) });
         }
     }
 
@@ -3384,13 +3388,13 @@ const DeclGen = struct {
             const tag_ptr_ty_ref = try self.ptrType(maybe_tag_ty.?, .Function);
             const ptr_id = try self.accessChain(tag_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.tag_index))});
             const tag_id = try self.constInt(tag_ty_ref, tag_int);
-            try self.store(maybe_tag_ty.?, ptr_id, tag_id, false);
+            try self.store(maybe_tag_ty.?, ptr_id, tag_id, .{});
         }
 
         if (layout.active_field_size != 0) {
             const active_field_ptr_ty_ref = try self.ptrType(layout.active_field_ty, .Function);
             const ptr_id = try self.accessChain(active_field_ptr_ty_ref, tmp_id, &.{@as(u32, @intCast(layout.active_field_index))});
-            try self.store(layout.active_field_ty, ptr_id, payload.?, false);
+            try self.store(layout.active_field_ty, ptr_id, payload.?, .{});
         } else {
             assert(payload == null);
         }
@@ -3468,7 +3472,7 @@ const DeclGen = struct {
                         .id_result = tmp_id,
                         .storage_class = .Function,
                     });
-                    try self.store(object_ty, tmp_id, object_id, false);
+                    try self.store(object_ty, tmp_id, object_id, .{});
                     const casted_tmp_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
                         .id_result_type = self.typeId(un_active_ptr_ty_ref),
@@ -3477,7 +3481,7 @@ const DeclGen = struct {
                     });
                     const layout = self.unionLayout(object_ty, field_index);
                     const field_ptr_id = try self.accessChain(field_ptr_ty_ref, casted_tmp_id, &.{layout.active_field_index});
-                    return try self.load(field_ty, field_ptr_id, false);
+                    return try self.load(field_ty, field_ptr_id, .{});
                 },
             },
             else => unreachable,
@@ -3730,7 +3734,7 @@ const DeclGen = struct {
         const operand = try self.resolve(ty_op.operand);
         if (!ptr_ty.isVolatilePtr(mod) and self.liveness.isUnused(inst)) return null;
 
-        return try self.load(elem_ty, operand, ptr_ty.isVolatilePtr(mod));
+        return try self.load(elem_ty, operand, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
     }
 
     fn airStore(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3740,7 +3744,7 @@ const DeclGen = struct {
         const ptr = try self.resolve(bin_op.lhs);
         const value = try self.resolve(bin_op.rhs);
 
-        try self.store(elem_ty, ptr, value, ptr_ty.isVolatilePtr(self.module));
+        try self.store(elem_ty, ptr, value, .{ .is_volatile = ptr_ty.isVolatilePtr(self.module) });
     }
 
     fn airLoop(self: *DeclGen, inst: Air.Inst.Index) !void {
@@ -3804,7 +3808,7 @@ const DeclGen = struct {
         }
 
         const ptr = try self.resolve(un_op);
-        const value = try self.load(ret_ty, ptr, ptr_ty.isVolatilePtr(mod));
+        const value = try self.load(ret_ty, ptr, .{ .is_volatile = ptr_ty.isVolatilePtr(mod) });
         try self.func.body.emit(self.spv.gpa, .OpReturnValue, .{
             .value = value,
         });