Commit a1cb9563f6

Ali Cheraghi <alichraghi@proton.me>
2024-10-31 23:33:33
spirv: Uniform/PushConstant variables
- Rename GPU address spaces to match the SPIR-V spec.
- Emit the `Block` decoration for Uniform/PushConstant variables.
- Don't emit `OpTypeForwardPointer` for non-OpenCL targets. (There is still a false positive involving recursive structs.)

Signed-off-by: Ali Cheraghi <alichraghi@proton.me>
1 parent 17a87d7
Changed files (7)
lib/std/builtin.zig
@@ -514,6 +514,7 @@ pub const AddressSpace = enum(u5) {
     input,
     output,
     uniform,
+    push_constant,
 
     // AVR address spaces.
     flash,
lib/std/Target.zig
@@ -1479,7 +1479,7 @@ pub const Cpu = struct {
                 .fs, .gs, .ss => arch == .x86_64 or arch == .x86,
                 .global, .constant, .local, .shared => is_gpu,
                 .param => is_nvptx,
-                .input, .output, .uniform => is_spirv,
+                .input, .output, .uniform, .push_constant => is_spirv,
                 // TODO this should also check how many flash banks the cpu has
                 .flash, .flash1, .flash2, .flash3, .flash4, .flash5 => arch == .avr,
 
src/codegen/spirv/Module.zig
@@ -402,9 +402,7 @@ pub fn resolveString(self: *Module, string: []const u8) !IdRef {
     return id;
 }
 
-pub fn structType(self: *Module, types: []const IdRef, maybe_names: ?[]const []const u8) !IdRef {
-    const result_id = self.allocId();
-
+pub fn structType(self: *Module, result_id: IdResult, types: []const IdRef, maybe_names: ?[]const []const u8) !void {
     try self.sections.types_globals_constants.emit(self.gpa, .OpTypeStruct, .{
         .id_result = result_id,
         .id_ref = types,
@@ -416,8 +414,6 @@ pub fn structType(self: *Module, types: []const IdRef, maybe_names: ?[]const []c
             try self.memberDebugName(result_id, @intCast(i), name);
         }
     }
-
-    return result_id;
 }
 
 pub fn boolType(self: *Module) !IdRef {
src/codegen/spirv.zig
@@ -897,7 +897,7 @@ const NavGen = struct {
         const result_ty_id = try self.resolveType(ty, repr);
         const ip = &zcu.intern_pool;
 
-        log.debug("lowering constant: ty = {}, val = {}", .{ ty.fmt(pt), val.fmtValue(pt) });
+        log.debug("lowering constant: ty = {}, val = {}, key = {s}", .{ ty.fmt(pt), val.fmtValue(pt), @tagName(ip.indexToKey(val.toIntern())) });
         if (val.isUndefDeep(zcu)) {
             return self.spv.constUndef(result_ty_id);
         }
@@ -1167,7 +1167,6 @@ const NavGen = struct {
 
     fn derivePtr(self: *NavGen, derivation: Value.PointerDeriveStep) Error!IdRef {
         const pt = self.pt;
-        const zcu = pt.zcu;
         switch (derivation) {
             .comptime_alloc_ptr, .comptime_field_ptr => unreachable,
             .int => |int| {
@@ -1211,10 +1210,6 @@ const NavGen = struct {
                     if (oac.byte_offset != 0) break :disallow;
                     // Allow changing the pointer type child only to restructure arrays.
                     // e.g. [3][2]T to T is fine, as is [2]T -> [2][1]T.
-                    const src_base_ty = parent_ptr_ty.arrayBase(zcu)[0];
-                    const dest_base_ty = oac.new_ptr_ty.arrayBase(zcu)[0];
-                    if (self.getTarget().os.tag == .vulkan and src_base_ty.toIntern() != dest_base_ty.toIntern()) break :disallow;
-
                     const result_ty_id = try self.resolveType(oac.new_ptr_ty, .direct);
                     const result_ptr_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
@@ -1224,7 +1219,7 @@ const NavGen = struct {
                     });
                     return result_ptr_id;
                 }
-                return self.fail("Cannot perform pointer cast: '{}' to '{}'", .{
+                return self.fail("cannot perform pointer cast: '{}' to '{}'", .{
                     parent_ptr_ty.fmt(pt),
                     oac.new_ptr_ty.fmt(pt),
                 });
@@ -1308,12 +1303,12 @@ const NavGen = struct {
             .global, .invocation_global => spv_decl.result_id,
         };
 
-        const final_storage_class = self.spvStorageClass(nav.status.resolved.@"addrspace");
-        try self.addFunctionDep(spv_decl_index, final_storage_class);
+        const storage_class = self.spvStorageClass(nav.status.resolved.@"addrspace");
+        try self.addFunctionDep(spv_decl_index, storage_class);
 
-        const decl_ptr_ty_id = try self.ptrType(nav_ty, final_storage_class);
+        const decl_ptr_ty_id = try self.ptrType(nav_ty, storage_class);
 
-        const ptr_id = switch (final_storage_class) {
+        const ptr_id = switch (storage_class) {
             .Generic => try self.castToGeneric(decl_ptr_ty_id, decl_id),
             else => decl_id,
         };
@@ -1399,6 +1394,10 @@ const NavGen = struct {
 
         const child_ty_id = try self.resolveType(child_ty, child_repr);
 
+        if (storage_class == .Uniform or storage_class == .PushConstant) {
+            try self.spv.decorate(child_ty_id, .Block);
+        }
+
         try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
             .id_result = result_id,
             .storage_class = storage_class,
@@ -1503,10 +1502,13 @@ const NavGen = struct {
             member_names[layout.padding_index] = "(padding)";
         }
 
-        const result_id = try self.spv.structType(member_types[0..layout.total_fields], member_names[0..layout.total_fields]);
+        const result_id = self.spv.allocId();
+        try self.spv.structType(result_id, member_types[0..layout.total_fields], member_names[0..layout.total_fields]);
+
         const type_name = try self.resolveTypeName(ty);
         defer self.gpa.free(type_name);
         try self.spv.debugName(result_id, type_name);
+
         return result_id;
     }
 
@@ -1700,10 +1702,13 @@ const NavGen = struct {
                 }
 
                 const size_ty_id = try self.resolveType(Type.usize, .direct);
-                return self.spv.structType(
+                const result_id = self.spv.allocId();
+                try self.spv.structType(
+                    result_id,
                     &.{ ptr_ty_id, size_ty_id },
                     &.{ "ptr", "len" },
                 );
+                return result_id;
             },
             .vector => {
                 const elem_ty = ty.childType(zcu);
@@ -1730,10 +1735,13 @@ const NavGen = struct {
                             member_index += 1;
                         }
 
-                        const result_id = try self.spv.structType(member_types[0..member_index], null);
+                        const result_id = self.spv.allocId();
+                        try self.spv.structType(result_id, member_types[0..member_index], null);
+
                         const type_name = try self.resolveTypeName(ty);
                         defer self.gpa.free(type_name);
                         try self.spv.debugName(result_id, type_name);
+
                         return result_id;
                     },
                     .struct_type => ip.loadStructType(ty.toIntern()),
@@ -1750,7 +1758,9 @@ const NavGen = struct {
                 var member_names = std.ArrayList([]const u8).init(self.gpa);
                 defer member_names.deinit();
 
+                var index: u32 = 0;
                 var it = struct_type.iterateRuntimeOrder(ip);
+                const result_id = self.spv.allocId();
                 while (it.next()) |field_index| {
                     const field_ty = Type.fromInterned(struct_type.field_types.get(ip)[field_index]);
                     if (!field_ty.hasRuntimeBitsIgnoreComptime(zcu)) {
@@ -1758,16 +1768,25 @@ const NavGen = struct {
                         continue;
                     }
 
+                    if (target.os.tag == .vulkan) {
+                        try self.spv.decorateMember(result_id, index, .{ .Offset = .{
+                            .byte_offset = @intCast(ty.structFieldOffset(field_index, zcu)),
+                        } });
+                    }
                     const field_name = struct_type.fieldName(ip, field_index).unwrap() orelse
                         try ip.getOrPutStringFmt(zcu.gpa, pt.tid, "{d}", .{field_index}, .no_embedded_nulls);
                     try member_types.append(try self.resolveType(field_ty, .indirect));
                     try member_names.append(field_name.toSlice(ip));
+
+                    index += 1;
                 }
 
-                const result_id = try self.spv.structType(member_types.items, member_names.items);
+                try self.spv.structType(result_id, member_types.items, member_names.items);
+
                 const type_name = try self.resolveTypeName(ty);
                 defer self.gpa.free(type_name);
                 try self.spv.debugName(result_id, type_name);
+
                 return result_id;
             },
             .optional => {
@@ -1787,10 +1806,13 @@ const NavGen = struct {
 
                 const bool_ty_id = try self.resolveType(Type.bool, .indirect);
 
-                return try self.spv.structType(
+                const result_id = self.spv.allocId();
+                try self.spv.structType(
+                    result_id,
                     &.{ payload_ty_id, bool_ty_id },
                     &.{ "payload", "valid" },
                 );
+                return result_id;
             },
             .@"union" => return try self.resolveUnionType(ty),
             .error_set => return try self.resolveType(Type.u16, repr),
@@ -1819,7 +1841,9 @@ const NavGen = struct {
                     // TODO: ABI padding?
                 }
 
-                return try self.spv.structType(&member_types, &member_names);
+                const result_id = self.spv.allocId();
+                try self.spv.structType(result_id, &member_types, &member_names);
+                return result_id;
             },
             .@"opaque" => {
                 const type_name = try self.resolveTypeName(ty);
@@ -1849,7 +1873,7 @@ const NavGen = struct {
         const target = self.getTarget();
         return switch (as) {
             .generic => switch (target.os.tag) {
-                .vulkan => .Private,
+                .vulkan => .Function,
                 .opencl => .Generic,
                 else => unreachable,
             },
@@ -1861,6 +1885,7 @@ const NavGen = struct {
                 else => unreachable,
             },
             .constant => .UniformConstant,
+            .push_constant => .PushConstant,
             .input => .Input,
             .output => .Output,
             .uniform => .Uniform,
@@ -2958,10 +2983,8 @@ const NavGen = struct {
                     const spv_err_decl_index = try self.spv.allocDecl(.global);
                     try self.spv.declareDeclDeps(spv_err_decl_index, &.{});
 
-                    const push_constant_struct_ty_id = try self.spv.structType(
-                        &.{ptr_anyerror_ty_id},
-                        &.{"error_out_ptr"},
-                    );
+                    const push_constant_struct_ty_id = self.spv.allocId();
+                    try self.spv.structType(push_constant_struct_ty_id, &.{ptr_anyerror_ty_id}, &.{"error_out_ptr"});
                     try self.spv.decorate(push_constant_struct_ty_id, .Block);
                     try self.spv.decorateMember(push_constant_struct_ty_id, 0, .{ .Offset = .{ .byte_offset = 0 } });
 
@@ -3145,15 +3168,15 @@ const NavGen = struct {
                 };
                 assert(maybe_init_val == null); // TODO
 
-                const final_storage_class = self.spvStorageClass(nav.status.resolved.@"addrspace");
-                assert(final_storage_class != .Generic); // These should be instance globals
+                const storage_class = self.spvStorageClass(nav.status.resolved.@"addrspace");
+                assert(storage_class != .Generic); // These should be instance globals
 
-                const ptr_ty_id = try self.ptrType(ty, final_storage_class);
+                const ptr_ty_id = try self.ptrType(ty, storage_class);
 
                 try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpVariable, .{
                     .id_result_type = ptr_ty_id,
                     .id_result = result_id,
-                    .storage_class = final_storage_class,
+                    .storage_class = storage_class,
                 });
 
                 try self.spv.debugName(result_id, nav.fqn.toSlice(ip));
src/link/SpirV.zig
@@ -296,7 +296,7 @@ fn writeCapabilities(spv: *SpvModule, target: std.Target) !void {
     // TODO: Integrate with a hypothetical feature system
     const caps: []const spec.Capability = switch (target.os.tag) {
         .opencl => &.{ .Kernel, .Addresses, .Int8, .Int16, .Int64, .Float64, .Float16, .Vector16, .GenericPointer },
-        .vulkan => &.{ .Shader, .PhysicalStorageBufferAddresses, .StoragePushConstant16, .Int8, .Int16, .Int64, .Float64, .Float16 },
+        .vulkan => &.{ .Shader, .PhysicalStorageBufferAddresses, .Int8, .Int16, .Int64, .Float64, .Float16 },
         else => unreachable,
     };
 
src/Sema.zig
@@ -37820,7 +37820,7 @@ pub fn analyzeAsAddressSpace(
         .gs, .fs, .ss => (arch == .x86 or arch == .x86_64) and ctx == .pointer,
         // TODO: check that .shared and .local are left uninitialized
         .param => is_nv,
-        .input, .output, .uniform => is_spirv,
+        .input, .output, .uniform, .push_constant => is_spirv,
         .global, .shared, .local => is_gpu,
         .constant => is_gpu and (ctx == .constant),
         // TODO this should also check how many flash banks the cpu has
src/target.zig
@@ -418,7 +418,7 @@ pub fn arePointersLogical(target: std.Target, as: AddressSpace) bool {
         .global => false,
         // TODO: Allowed with VK_KHR_variable_pointers.
         .shared => true,
-        .constant, .local, .input, .output, .uniform => true,
+        .constant, .local, .input, .output, .uniform, .push_constant => true,
         else => unreachable,
     };
 }