Commit 3f2025f59e

Robin Voetter <robin@voetter.nl>
2023-04-08 01:37:00
spirv: emit interface variables for entry points
Also actually implement generating the OpEntryPoint instructions.
1 parent 405f729
Changed files (3)
src/codegen/spirv/Assembler.zig
@@ -435,10 +435,12 @@ fn processGenericInstruction(self: *Assembler) !?AsmValue {
         .Annotation => &self.spv.sections.annotations,
         .TypeDeclaration => unreachable, // Handled elsewhere.
         else => switch (self.inst.opcode) {
-            .OpEntryPoint => &self.spv.sections.entry_points,
+            // TODO: This should emit a proper entry point.
+            .OpEntryPoint => unreachable, // &self.spv.sections.entry_points,
             .OpExecutionMode, .OpExecutionModeId => &self.spv.sections.execution_modes,
             .OpVariable => switch (@intToEnum(spec.StorageClass, operands[2].value)) {
                 .Function => &self.func.prologue,
+                // TODO: Emit a decl dependency
                 else => &self.spv.sections.types_globals_constants,
             },
             // Default case - to be worked out further.
src/codegen/spirv/Module.zig
@@ -93,6 +93,14 @@ pub const Global = struct {
     end_inst: u32,
 };
 
+/// This models a kernel entry point.
+pub const EntryPoint = struct {
+    /// The declaration that should be exported.
+    decl_index: Decl.Index,
+    /// The name of the kernel to be exported.
+    name: []const u8,
+};
+
 /// A general-purpose allocator which may be used to allocate resources for this module
 gpa: Allocator,
 
@@ -107,8 +115,7 @@ sections: struct {
     extensions: Section = .{},
     // OpExtInstImport instructions - skip for now.
     // memory model defined by target, not required here.
-    /// OpEntryPoint instructions.
-    entry_points: Section = .{},
+    /// OpEntryPoint instructions - Handled by `self.entry_points`.
     /// OpExecutionMode and OpExecutionModeId instructions.
     execution_modes: Section = .{},
     /// OpString, OpSourcExtension, OpSource, OpSourceContinued.
@@ -143,8 +150,13 @@ type_cache: TypeCache = .{},
 /// Set of Decls, referred to by Decl.Index.
 decls: std.ArrayListUnmanaged(Decl) = .{},
 
+/// List of dependencies, per decl. This list holds all the dependencies, sliced by the
+/// begin_dep and end_dep in `self.decls`.
 decl_deps: std.ArrayListUnmanaged(Decl.Index) = .{},
 
+/// The list of entry points that should be exported from this module.
+entry_points: std.ArrayListUnmanaged(EntryPoint) = .{},
+
 /// The fields in this structure help to maintain the required order for global variables.
 globals: struct {
     /// Set of globals, referred to by Decl.Index.
@@ -166,7 +178,6 @@ pub fn init(gpa: Allocator, arena: Allocator) Module {
 pub fn deinit(self: *Module) void {
     self.sections.capabilities.deinit(self.gpa);
     self.sections.extensions.deinit(self.gpa);
-    self.sections.entry_points.deinit(self.gpa);
     self.sections.execution_modes.deinit(self.gpa);
     self.sections.debug_strings.deinit(self.gpa);
     self.sections.debug_names.deinit(self.gpa);
@@ -180,6 +191,8 @@ pub fn deinit(self: *Module) void {
     self.decls.deinit(self.gpa);
     self.decl_deps.deinit(self.gpa);
 
+    self.entry_points.deinit(self.gpa);
+
     self.globals.globals.deinit(self.gpa);
     self.globals.section.deinit(self.gpa);
 
@@ -202,16 +215,16 @@ pub fn idBound(self: Module) Word {
 
 fn orderGlobalsInto(
     self: *Module,
-    index: Decl.Index,
+    decl_index: Decl.Index,
     section: *Section,
     seen: *std.DynamicBitSetUnmanaged,
 ) !void {
-    const decl = self.declPtr(index);
+    const decl = self.declPtr(decl_index);
     const deps = self.decl_deps.items[decl.begin_dep..decl.end_dep];
-    const global = self.globalPtr(index).?;
+    const global = self.globalPtr(decl_index).?;
     const insts = self.globals.section.instructions.items[global.begin_inst..global.end_inst];
 
-    seen.set(@enumToInt(index));
+    seen.set(@enumToInt(decl_index));
 
     for (deps) |dep| {
         if (!seen.isSet(@enumToInt(dep))) {
@@ -229,6 +242,8 @@ fn orderGlobals(self: *Module) !Section {
     defer seen.deinit(self.gpa);
 
     var ordered_globals = Section{};
+    errdefer ordered_globals.deinit(self.gpa);
+
     for (globals) |decl_index| {
         if (!seen.isSet(@enumToInt(decl_index))) {
             try self.orderGlobalsInto(decl_index, &ordered_globals, &seen);
@@ -238,6 +253,56 @@ fn orderGlobals(self: *Module) !Section {
     return ordered_globals;
 }
 
+fn addEntryPointDeps(
+    self: *Module,
+    decl_index: Decl.Index,
+    seen: *std.DynamicBitSetUnmanaged,
+    interface: *std.ArrayList(IdRef),
+) !void {
+    const decl = self.declPtr(decl_index);
+    const deps = self.decl_deps.items[decl.begin_dep..decl.end_dep];
+
+    seen.set(@enumToInt(decl_index));
+
+    if (self.globalPtr(decl_index)) |global| {
+        try interface.append(global.result_id);
+    }
+
+    for (deps) |dep| {
+        if (!seen.isSet(@enumToInt(dep))) {
+            try self.addEntryPointDeps(dep, seen, interface);
+        }
+    }
+}
+
+fn entryPoints(self: *Module) !Section {
+    var entry_points = Section{};
+    errdefer entry_points.deinit(self.gpa);
+
+    var interface = std.ArrayList(IdRef).init(self.gpa);
+    defer interface.deinit();
+
+    var seen = try std.DynamicBitSetUnmanaged.initEmpty(self.gpa, self.decls.items.len);
+    defer seen.deinit(self.gpa);
+
+    for (self.entry_points.items) |entry_point| {
+        interface.items.len = 0;
+        seen.setRangeValue(.{ .start = 0, .end = self.decls.items.len }, false);
+
+        try self.addEntryPointDeps(entry_point.decl_index, &seen, &interface);
+
+        const entry_point_id = self.declPtr(entry_point.decl_index).result_id;
+        try entry_points.emit(self.gpa, .OpEntryPoint, .{
+            .execution_model = .Kernel,
+            .entry_point = entry_point_id,
+            .name = entry_point.name,
+            .interface = interface.items,
+        });
+    }
+
+    return entry_points;
+}
+
 /// Emit this module as a spir-v binary.
 pub fn flush(self: *Module, file: std.fs.File) !void {
     // See SPIR-V Spec section 2.3, "Physical Layout of a SPIR-V Module and Instruction"
@@ -256,12 +321,15 @@ pub fn flush(self: *Module, file: std.fs.File) !void {
     var globals = try self.orderGlobals();
     defer globals.deinit(self.gpa);
 
+    var entry_points = try self.entryPoints();
+    defer entry_points.deinit(self.gpa);
+
     // Note: needs to be kept in order according to section 2.3!
     const buffers = &[_][]const Word{
         &header,
         self.sections.capabilities.toWords(),
         self.sections.extensions.toWords(),
-        self.sections.entry_points.toWords(),
+        entry_points.toWords(),
         self.sections.execution_modes.toWords(),
         self.sections.debug_strings.toWords(),
         self.sections.debug_names.toWords(),
@@ -795,3 +863,10 @@ pub fn endGlobal(self: *Module, global_index: Decl.Index, begin_inst: u32) void
     global.begin_inst = begin_inst;
     global.end_inst = @intCast(u32, self.globals.section.instructions.items.len);
 }
+
+pub fn declareEntryPoint(self: *Module, decl_index: Decl.Index, name: []const u8) !void {
+    try self.entry_points.append(self.gpa, .{
+        .decl_index = decl_index,
+        .name = try self.arena.dupe(u8, name),
+    });
+}
src/codegen/spirv.zig
@@ -1343,7 +1343,7 @@ pub const DeclGen = struct {
     ///   OpFunctionEnd
     /// TODO is to also write out the error as a function call parameter, and to somehow fetch
     /// the name of an error in the text executor.
-    fn generateTestEntryPoint(self: *DeclGen, name: []const u8, func: IdResult) !void {
+    fn generateTestEntryPoint(self: *DeclGen, name: []const u8, spv_test_decl_index: SpvModule.Decl.Index) !void {
         const anyerror_ty_ref = try self.resolveType(Type.anyerror, .direct);
         const ptr_anyerror_ty_ref = try self.spv.ptrType(anyerror_ty_ref, .CrossWorkgroup, null);
         const void_ty_ref = try self.resolveType(Type.void, .direct);
@@ -1357,7 +1357,11 @@ pub const DeclGen = struct {
             break :blk try self.spv.resolveType(SpvType.initPayload(&proto_payload.base));
         };
 
-        const kernel_id = self.spv.allocId();
+        const test_id = self.spv.declPtr(spv_test_decl_index).result_id;
+
+        const spv_decl_index = try self.spv.allocDecl(.func);
+        const kernel_id = self.spv.declPtr(spv_decl_index).result_id;
+
         const error_id = self.spv.allocId();
         const p_error_id = self.spv.allocId();
 
@@ -1378,7 +1382,7 @@ pub const DeclGen = struct {
         try section.emit(self.spv.gpa, .OpFunctionCall, .{
             .id_result_type = self.typeId(anyerror_ty_ref),
             .id_result = error_id,
-            .function = func,
+            .function = test_id,
         });
         try section.emit(self.spv.gpa, .OpStore, .{
             .pointer = p_error_id,
@@ -1387,11 +1391,13 @@ pub const DeclGen = struct {
         try section.emit(self.spv.gpa, .OpReturn, {});
         try section.emit(self.spv.gpa, .OpFunctionEnd, {});
 
-        try self.spv.sections.entry_points.emit(self.spv.gpa, .OpEntryPoint, .{
-            .execution_model = .Kernel,
-            .entry_point = kernel_id,
-            .name = name,
-        });
+        try self.spv.declareDeclDeps(spv_decl_index, &.{spv_test_decl_index});
+
+        // Just generate a quick other name because the intel runtime crashes when the entry-
+        // point name is the same as a different OpName.
+        const test_name = try std.fmt.allocPrint(self.gpa, "test {s}", .{name});
+        defer self.gpa.free(test_name);
+        try self.spv.declareEntryPoint(spv_decl_index, test_name);
     }
 
     fn genDecl(self: *DeclGen) !void {
@@ -1451,7 +1457,7 @@ pub const DeclGen = struct {
             });
 
             if (self.module.test_functions.contains(self.decl_index)) {
-                try self.generateTestEntryPoint(fqn, decl_id);
+                try self.generateTestEntryPoint(fqn, spv_decl_index);
             }
         } else {
             const init_val = if (decl.val.castTag(.variable)) |payload|