Commit efe7fae6af

Robin Voetter <robin@voetter.nl>
2023-04-09 01:27:02
spirv: temporarily emit test kernels
SPIR-V cannot represent function pointers without extensions that no vendor implements. For the time being, generate a test kernel for each error, so that we can at least run SOME tests. In the future we may be able to emulate function pointers in some way, but that is not today.
1 parent 719d47d
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -187,8 +187,8 @@ fn orderGlobalsInto(
     seen: *std.DynamicBitSetUnmanaged,
 ) !void {
     const node = self.globals.nodes.items[@enumToInt(global_index)];
-    const deps = self.globals.dependencies.items[node.begin_dep .. node.end_dep];
-    const insts = self.globals.section.instructions.items[node.begin_inst .. node.end_inst];
+    const deps = self.globals.dependencies.items[node.begin_dep..node.end_dep];
+    const insts = self.globals.section.instructions.items[node.begin_inst..node.end_inst];
 
     seen.set(@enumToInt(global_index));
 
@@ -725,7 +725,7 @@ pub fn allocGlobal(self: *Module) !Global.Index {
         .begin_inst = undefined,
         .end_inst = undefined,
         .begin_dep = undefined,
-            .end_dep = undefined,
+        .end_dep = undefined,
     });
     return @intToEnum(Global.Index, @intCast(u32, self.globals.nodes.items.len - 1));
 }
src/codegen/spirv.zig
@@ -272,7 +272,7 @@ pub const DeclGen = struct {
 
         if (!entry.found_existing) {
             if (decl.val.castTag(.function)) |_| {
-                entry.value_ptr.* = .{.func = .{ .result_id = result_id }};
+                entry.value_ptr.* = .{ .func = .{ .result_id = result_id } };
             } else {
                 entry.value_ptr.* = .{ .global = try self.spv.allocGlobal() };
             }
@@ -418,11 +418,7 @@ pub const DeclGen = struct {
 
     fn genUndef(self: *DeclGen, ty_ref: SpvType.Ref) Error!IdRef {
         const result_id = self.spv.allocId();
-        try self.spv.sections.types_globals_constants.emit(
-            self.spv.gpa,
-            .OpUndef,
-            .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id }
-        );
+        try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpUndef, .{ .id_result_type = self.typeId(ty_ref), .id_result = result_id });
         return result_id;
     }
 
@@ -899,11 +895,12 @@ pub const DeclGen = struct {
             .initializer = constant_struct_id,
         });
         // TODO: Set alignment of OpVariable.
-        // TODO: We may be able to eliminate this cast.
+        // TODO: We may be able to eliminate these casts.
+        const const_ptr_id = try self.makePointerConstant(section, ptr_constant_struct_ty_ref, var_id);
         try section.emitSpecConstantOp(self.spv.gpa, .OpBitcast, .{
             .id_result_type = self.typeId(ptr_ty_ref),
             .id_result = result_id,
-            .operand = var_id,
+            .operand = const_ptr_id,
         });
     }
 
@@ -1267,13 +1264,13 @@ pub const DeclGen = struct {
                 // Similar to unions, we're going to put the most aligned member first.
                 if (error_align > payload_align) {
                     // Put the error first
-                    members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error"  });
-                    members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload"  });
+                    members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" });
+                    members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" });
                     // TODO: ABI padding?
                 } else {
                     // Put the payload first.
-                    members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload"  });
-                    members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error"  });
+                    members.appendAssumeCapacity(.{ .ty = payload_ty_ref, .name = "payload" });
+                    members.appendAssumeCapacity(.{ .ty = error_ty_ref, .name = "error" });
                     // TODO: ABI padding?
                 }
 
@@ -1302,12 +1299,81 @@ pub const DeclGen = struct {
         };
     }
 
+    /// The SPIR-V backend is not yet advanced enough to support the std testing infrastructure.
+    /// In order to be able to run tests, we "temporarily" lower test kernels into separate entry-
+    /// points. The test executor will then be able to invoke these to run the tests.
+    /// Note that tests are lowered according to std.builtin.TestFn, which is `fn () anyerror!void`.
+    /// (anyerror!void has the same layout as anyerror).
+    /// Each test declaration generates a function like.
+    ///   %anyerror = OpTypeInt 0 16
+    ///   %p_anyerror = OpTypePointer CrossWorkgroup %anyerror
+    ///   %K = OpTypeFunction %void %p_anyerror
+    ///
+    ///   %test = OpFunction %void %K
+    ///   %p_err = OpFunctionParameter %p_anyerror
+    ///   %lbl = OpLabel
+    ///   %result = OpFunctionCall %anyerror %func
+    ///   OpStore %p_err %result
+    ///   OpFunctionEnd
+    /// TODO is to also write out the error as a function call parameter, and to somehow fetch
+    /// the name of an error in the text executor.
+    fn generateTestEntryPoint(self: *DeclGen, name: []const u8, func: IdResult) !void {
+        const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct);
+        const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup, null);
+        const void_ty_ref = try self.resolveType(Type.void, .direct);
+
+        const kernel_proto_ty_ref = blk: {
+            const proto_payload = try self.spv.arena.create(SpvType.Payload.Function);
+            proto_payload.* = .{
+                .return_type = void_ty_ref,
+                .parameters = try self.spv.arena.dupe(SpvType.Ref, &.{ptr_anyerror_ty_ref}),
+            };
+            break :blk try self.spv.resolveType(SpvType.initPayload(&proto_payload.base));
+        };
+
+        const kernel_id = self.spv.allocId();
+        const error_id = self.spv.allocId();
+        const p_error_id = self.spv.allocId();
+
+        const section = &self.spv.sections.functions;
+        try section.emit(self.spv.gpa, .OpFunction, .{
+            .id_result_type = self.typeId(void_ty_ref),
+            .id_result = kernel_id,
+            .function_control = .{},
+            .function_type = self.typeId(kernel_proto_ty_ref),
+        });
+        try section.emit(self.spv.gpa, .OpFunctionParameter, .{
+            .id_result_type = self.typeId(ptr_anyerror_ty_ref),
+            .id_result = p_error_id,
+        });
+        try section.emit(self.spv.gpa, .OpLabel, .{
+            .id_result = self.spv.allocId(),
+        });
+        try section.emit(self.spv.gpa, .OpFunctionCall, .{
+            .id_result_type = self.typeId(anyerror_ty_ref),
+            .id_result = error_id,
+            .function = func,
+        });
+        try section.emit(self.spv.gpa, .OpStore, .{
+            .pointer = p_error_id,
+            .object = error_id,
+        });
+        try section.emit(self.spv.gpa, .OpReturn, {});
+        try section.emit(self.spv.gpa, .OpFunctionEnd, {});
+
+        try self.spv.sections.entry_points.emit(self.spv.gpa, .OpEntryPoint, .{
+            .execution_model = .Kernel,
+            .entry_point = kernel_id,
+            .name = name,
+        });
+    }
+
     fn genDecl(self: *DeclGen) !void {
         const decl = self.module.declPtr(self.decl_index);
         const link = try self.resolveDecl(self.decl_index);
 
         if (decl.val.castTag(.function)) |_| {
-            log.debug("genDecl function {s} = {}", .{decl.name, link.func.result_id.id});
+            log.debug("genDecl function {s} = {}", .{ decl.name, link.func.result_id.id });
 
             assert(decl.ty.zigTypeTag() == .Fn);
             const prototype_id = try self.resolveTypeId(decl.ty);
@@ -1356,6 +1422,10 @@ pub const DeclGen = struct {
                 .target = link.func.result_id,
                 .name = fqn,
             });
+
+            if (self.module.test_functions.contains(self.decl_index)) {
+                try self.generateTestEntryPoint(fqn, link.func.result_id);
+            }
         } else {
             const init_val = if (decl.val.castTag(.variable)) |payload|
                 payload.data.init
@@ -1396,6 +1466,7 @@ pub const DeclGen = struct {
                 const ty_ref = try self.resolveType(decl.ty, .indirect);
                 const ptr_ty_ref = try self.spv.ptrType(ty_ref, storage_class, decl.@"align");
                 // TODO: Can we eliminate this cast?
+                // TODO: Const-wash pointer
                 try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{
                     .id_result_type = self.typeId(ptr_ty_ref),
                     .id_result = global_result_id,
@@ -2036,6 +2107,24 @@ pub const DeclGen = struct {
         return try self.structFieldPtr(result_ptr_ty, struct_ptr_ty, struct_ptr, field_index);
     }
 
+    /// We cannot use an OpVariable directly in an OpSpecConstantOp, but we can
+    /// after we insert a dummy AccessChain...
+    /// TODO: Get rid of this
+    fn makePointerConstant(
+        self: *DeclGen,
+        section: *SpvSection,
+        ptr_ty_ref: SpvType.Ref,
+        ptr_id: IdRef,
+    ) !IdRef {
+        const result_id = self.spv.allocId();
+        try section.emitSpecConstantOp(self.spv.gpa, .OpInBoundsAccessChain, .{
+            .id_result_type = self.typeId(ptr_ty_ref),
+            .id_result = result_id,
+            .base = ptr_id,
+        });
+        return result_id;
+    }
+
     fn variable(
         self: *DeclGen,
         comptime context: enum { function, global },
@@ -2088,11 +2177,14 @@ pub const DeclGen = struct {
                 .pointer = alloc_result_id,
             }),
             // TODO: Can we do without this cast or move it to runtime?
-            else => try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{
-                .id_result_type = self.typeId(ptr_ty_ref),
-                .id_result = result_id,
-                .pointer = alloc_result_id,
-            }),
+            else => {
+                const const_ptr_id = try self.makePointerConstant(section, actual_ptr_ty_ref, alloc_result_id);
+                try section.emitSpecConstantOp(self.spv.gpa, .OpPtrCastToGeneric, .{
+                    .id_result_type = self.typeId(ptr_ty_ref),
+                    .id_result = result_id,
+                    .pointer = const_ptr_id,
+                });
+            },
         }
     }