Commit 39013619b9

Robin Voetter <robin@voetter.nl>
2024-10-20 17:10:55
spirv: generate test entry points for vulkan
1 parent 7c69231
Changed files (2)
src/codegen/spirv.zig
@@ -169,6 +169,13 @@ pub const Object = struct {
     ///   via the usual `intern_map` mechanism.
     ptr_types: PtrTypeMap = .{},
 
+    /// For test declarations for Vulkan, we have to add a push constant with a pointer to a
+    /// buffer that we can use. We only need to generate this once, this holds the link information
+    /// related to that.
+    error_push_constant: ?struct {
+        push_constant_ptr: SpvModule.Decl.Index,
+    } = null,
+
     pub fn init(gpa: Allocator) Object {
         return .{
             .gpa = gpa,
@@ -2908,30 +2915,118 @@ const NavGen = struct {
             .flags = .{ .address_space = .global },
         });
         const ptr_anyerror_ty_id = try self.resolveType(ptr_anyerror_ty, .direct);
-        const kernel_proto_ty_id = try self.functionType(Type.void, &.{ptr_anyerror_ty});
-
-        const test_id = self.spv.declPtr(spv_test_decl_index).result_id;
 
         const spv_decl_index = try self.spv.allocDecl(.func);
         const kernel_id = self.spv.declPtr(spv_decl_index).result_id;
+        // for some reason we don't need to decorate the push constant here...
+        try self.spv.declareDeclDeps(spv_decl_index, &.{spv_test_decl_index});
+
+        const section = &self.spv.sections.functions;
+
+        const target = self.getTarget();
 
-        const error_id = self.spv.allocId();
         const p_error_id = self.spv.allocId();
+        switch (target.os.tag) {
+            .opencl => {
+                const kernel_proto_ty_id = try self.functionType(Type.void, &.{ptr_anyerror_ty});
 
-        const section = &self.spv.sections.functions;
-        try section.emit(self.spv.gpa, .OpFunction, .{
-            .id_result_type = try self.resolveType(Type.void, .direct),
-            .id_result = kernel_id,
-            .function_control = .{},
-            .function_type = kernel_proto_ty_id,
-        });
-        try section.emit(self.spv.gpa, .OpFunctionParameter, .{
-            .id_result_type = ptr_anyerror_ty_id,
-            .id_result = p_error_id,
-        });
-        try section.emit(self.spv.gpa, .OpLabel, .{
-            .id_result = self.spv.allocId(),
-        });
+                try section.emit(self.spv.gpa, .OpFunction, .{
+                    .id_result_type = try self.resolveType(Type.void, .direct),
+                    .id_result = kernel_id,
+                    .function_control = .{},
+                    .function_type = kernel_proto_ty_id,
+                });
+
+                try section.emit(self.spv.gpa, .OpFunctionParameter, .{
+                    .id_result_type = ptr_anyerror_ty_id,
+                    .id_result = p_error_id,
+                });
+
+                try section.emit(self.spv.gpa, .OpLabel, .{
+                    .id_result = self.spv.allocId(),
+                });
+            },
+            .vulkan => {
+                const ptr_ptr_anyerror_ty_id = self.spv.allocId();
+                try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
+                    .id_result = ptr_ptr_anyerror_ty_id,
+                    .storage_class = .PushConstant,
+                    .type = ptr_anyerror_ty_id,
+                });
+
+                if (self.object.error_push_constant == null) {
+                    const spv_err_decl_index = try self.spv.allocDecl(.global);
+                    try self.spv.declareDeclDeps(spv_err_decl_index, &.{});
+
+                    const push_constant_struct_ty_id = try self.spv.structType(
+                        &.{ptr_anyerror_ty_id},
+                        &.{"error_out_ptr"},
+                    );
+                    try self.spv.decorate(push_constant_struct_ty_id, .Block);
+                    try self.spv.decorateMember(push_constant_struct_ty_id, 0, .{ .Offset = .{ .byte_offset = 0 } });
+
+                    const ptr_push_constant_struct_ty_id = self.spv.allocId();
+                    try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
+                        .id_result = ptr_push_constant_struct_ty_id,
+                        .storage_class = .PushConstant,
+                        .type = push_constant_struct_ty_id,
+                    });
+
+                    try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpVariable, .{
+                        .id_result_type = ptr_push_constant_struct_ty_id,
+                        .id_result = self.spv.declPtr(spv_err_decl_index).result_id,
+                        .storage_class = .PushConstant,
+                    });
+
+                    self.object.error_push_constant = .{
+                        .push_constant_ptr = spv_err_decl_index,
+                    };
+                }
+
+                try self.spv.sections.execution_modes.emit(self.spv.gpa, .OpExecutionMode, .{
+                    .entry_point = kernel_id,
+                    .mode = .{ .LocalSize = .{
+                        .x_size = 1,
+                        .y_size = 1,
+                        .z_size = 1,
+                    } },
+                });
+
+                const kernel_proto_ty_id = try self.functionType(Type.void, &.{});
+                try section.emit(self.spv.gpa, .OpFunction, .{
+                    .id_result_type = try self.resolveType(Type.void, .direct),
+                    .id_result = kernel_id,
+                    .function_control = .{},
+                    .function_type = kernel_proto_ty_id,
+                });
+                try section.emit(self.spv.gpa, .OpLabel, .{
+                    .id_result = self.spv.allocId(),
+                });
+
+                const spv_err_decl_index = self.object.error_push_constant.?.push_constant_ptr;
+                const push_constant_id = self.spv.declPtr(spv_err_decl_index).result_id;
+
+                const zero_id = try self.constInt(Type.u32, 0, .direct);
+                // We cannot use OpInBoundsAccessChain to dereference cross-storage class, so we have to use
+                // a load.
+                const tmp = self.spv.allocId();
+                try section.emit(self.spv.gpa, .OpInBoundsAccessChain, .{
+                    .id_result_type = ptr_ptr_anyerror_ty_id,
+                    .id_result = tmp,
+                    .base = push_constant_id,
+                    .indexes = &.{zero_id},
+                });
+                try section.emit(self.spv.gpa, .OpLoad, .{
+                    .id_result_type = ptr_anyerror_ty_id,
+                    .id_result = p_error_id,
+                    .pointer = tmp,
+                });
+            },
+            else => unreachable,
+        }
+
+        const test_id = self.spv.declPtr(spv_test_decl_index).result_id;
+        const error_id = self.spv.allocId();
         try section.emit(self.spv.gpa, .OpFunctionCall, .{
             .id_result_type = anyerror_ty_id,
             .id_result = error_id,
@@ -2941,17 +3036,25 @@ const NavGen = struct {
         try section.emit(self.spv.gpa, .OpStore, .{
             .pointer = p_error_id,
             .object = error_id,
+            .memory_access = .{
+                .Aligned = .{ .literal_integer = @sizeOf(u16) },
+            },
         });
         try section.emit(self.spv.gpa, .OpReturn, {});
         try section.emit(self.spv.gpa, .OpFunctionEnd, {});
 
-        try self.spv.declareDeclDeps(spv_decl_index, &.{spv_test_decl_index});
-
         // Just generate a quick other name because the intel runtime crashes when the entry-
         // point name is the same as a different OpName.
         const test_name = try std.fmt.allocPrint(self.gpa, "test {s}", .{name});
         defer self.gpa.free(test_name);
-        try self.spv.declareEntryPoint(spv_decl_index, test_name, .Kernel);
+
+        const execution_mode: spec.ExecutionModel = switch (target.os.tag) {
+            .vulkan => .GLCompute,
+            .opencl => .Kernel,
+            else => unreachable,
+        };
+
+        try self.spv.declareEntryPoint(spv_decl_index, test_name, execution_mode);
     }
 
     fn genNav(self: *NavGen, do_codegen: bool) !void {
src/link/SpirV/lower_invocation_globals.zig
@@ -400,6 +400,15 @@ const ModuleBuilder = struct {
                     self.section.writeWords(inst.operands[2..]);
                     continue;
                 },
+                .OpExecutionMode, .OpExecutionModeId => {
+                    const original_id: ResultId = @enumFromInt(inst.operands[0]);
+                    const new_id_index = info.entry_points.getIndex(original_id).?;
+                    const new_id: ResultId = @enumFromInt(self.entry_point_new_id_base + new_id_index);
+                    try self.section.emitRaw(self.arena, inst.opcode, inst.operands.len);
+                    self.section.writeOperand(ResultId, new_id);
+                    self.section.writeWords(inst.operands[1..]);
+                    continue;
+                },
                 .OpTypeFunction => {
                     // Re-emitted in `emitFunctionTypes()`. We can do this because
                     // OpTypeFunction's may not currently be used anywhere that is not