Commit 8fa54eb798

Ali Cheraghi <alichraghi@proton.me>
2025-05-11 14:15:44
spirv: error when execution mode is set more than once
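
The built-in decoration helpers in lib/std/gpu.zig are replaced with
extern built-in variables, and the per-mode helpers (fragmentOrigin,
depthMode, workgroupSize, workgroupSizeHint) are replaced by a single
executionMode function that takes a tagged ExecutionMode union and
validates the entry point's calling convention. The SPIR-V assembler
now tracks entry points in a map and reports an error when
OpExecutionMode is set twice for the same entry point. Fragment entry
points with no explicit mode default to OriginUpperLeft on Vulkan and
OriginLowerLeft on OpenGL.

Rough usage sketch (the entry-point name is hypothetical, and calling
the helper from inside the entry point's own body is an assumption
based on how the other std.gpu decoration helpers are used):

    const gpu = @import("std").gpu;

    export fn fragMain() callconv(.spirv_fragment) void {
        // Allowed: depth_replacing matches the spirv_fragment calling convention.
        gpu.executionMode(fragMain, .depth_replacing);

        // Compile error: 'local_size' is only valid for the spirv_kernel
        // calling convention.
        // gpu.executionMode(fragMain, .{ .local_size = .{ .x = 8, .y = 8, .z = 1 } });

        // Rejected by the assembler: a second OpExecutionMode for the
        // same entry point is now an error.
        // gpu.executionMode(fragMain, .origin_upper_left);
    }
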
1 parent 9209f4b
Changed files (5)
lib/std/gpu.zig
@@ -1,81 +1,24 @@
 const std = @import("std.zig");
 
-/// Will make `ptr` contain the location of the current invocation within the
-/// global workgroup. Each component is equal to the index of the local workgroup
-/// multiplied by the size of the local workgroup plus `localInvocationId`.
-/// `ptr` must be a reference to variable or struct field.
-pub fn globalInvocationId(comptime ptr: *addrspace(.input) @Vector(3, u32)) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn GlobalInvocationId
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Will make that variable contain the location of the current cluster
-/// culling, task, mesh, or compute shader invocation within the local
-/// workgroup. Each component ranges from zero through to the size of the
-/// workgroup in that dimension minus one.
-/// `ptr` must be a reference to variable or struct field.
-pub fn localInvocationId(comptime ptr: *addrspace(.input) @Vector(3, u32)) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn LocalInvocationId
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Output vertex position from a `Vertex` entrypoint
-/// `ptr` must be a reference to variable or struct field.
-pub fn position(comptime ptr: *addrspace(.output) @Vector(4, f32)) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn Position
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Will make `ptr` contain the index of the vertex that is
-/// being processed by the current vertex shader invocation.
-/// `ptr` must be a reference to variable or struct field.
-pub fn vertexIndex(comptime ptr: *addrspace(.input) u32) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn VertexIndex
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Will make `ptr` contain the index of the instance that is
-/// being processed by the current vertex shader invocation.
-/// `ptr` must be a reference to variable or struct field.
-pub fn instanceIndex(comptime ptr: *addrspace(.input) u32) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn InstanceIndex
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Output fragment depth from a `Fragment` entrypoint
-/// `ptr` must be a reference to variable or struct field.
-pub fn fragmentCoord(comptime ptr: *addrspace(.input) @Vector(4, f32)) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn FragCoord
-        :
-        : [ptr] "" (ptr),
-    );
-}
-
-/// Output fragment depth from a `Fragment` entrypoint
-/// `ptr` must be a reference to variable or struct field.
-pub fn fragmentDepth(comptime ptr: *addrspace(.output) f32) void {
-    asm volatile (
-        \\OpDecorate %ptr BuiltIn FragDepth
-        :
-        : [ptr] "" (ptr),
-    );
-}
+pub const position_in = @extern(*addrspace(.input) @Vector(4, f32), .{ .name = "position" });
+pub const position_out = @extern(*addrspace(.output) @Vector(4, f32), .{ .name = "position" });
+pub const point_size_in = @extern(*addrspace(.input) f32, .{ .name = "point_size" });
+pub const point_size_out = @extern(*addrspace(.output) f32, .{ .name = "point_size" });
+pub extern const invocation_id: u32 addrspace(.input);
+pub extern const frag_coord: @Vector(4, f32) addrspace(.input);
+pub extern const point_coord: @Vector(2, f32) addrspace(.input);
+// TODO: direct/indirect values
+// pub extern const front_facing: bool addrspace(.input);
+// TODO: runtime array
+// pub extern const sample_mask;
+pub extern var frag_depth: f32 addrspace(.output);
+pub extern const num_workgroups: @Vector(3, u32) addrspace(.input);
+pub extern const workgroup_size: @Vector(3, u32) addrspace(.input);
+pub extern const workgroup_id: @Vector(3, u32) addrspace(.input);
+pub extern const local_invocation_id: @Vector(3, u32) addrspace(.input);
+pub extern const global_invocation_id: @Vector(3, u32) addrspace(.input);
+pub extern const vertex_index: u32 addrspace(.input);
+pub extern const instance_index: u32 addrspace(.input);
 
 /// Forms the main linkage for `input` and `output` address spaces.
 /// `ptr` must be a reference to variable or struct field.
@@ -101,74 +44,85 @@ pub fn binding(comptime ptr: anytype, comptime set: u32, comptime bind: u32) voi
     );
 }
 
-pub const Origin = enum(u32) {
-    /// Increase toward the right and downward
-    upper_left = 7,
-    /// Increase toward the right and upward
-    lower_left = 8,
-};
-
-/// The coordinates appear to originate in the specified `origin`.
-/// Only valid with the `Fragment` calling convention.
-pub fn fragmentOrigin(comptime entry_point: anytype, comptime origin: Origin) void {
-    asm volatile (
-        \\OpExecutionMode %entry_point $origin
-        :
-        : [entry_point] "" (entry_point),
-          [origin] "c" (@intFromEnum(origin)),
-    );
-}
-
-pub const DepthMode = enum(u32) {
-    /// Declares that this entry point dynamically writes the
-    /// `fragmentDepth` built in-decorated variable.
-    replacing = 12,
+pub const ExecutionMode = union(Tag) {
+    /// Sets the origin of the framebuffer to the upper-left corner.
+    origin_upper_left,
+    /// Sets the origin of the framebuffer to the lower-left corner.
+    origin_lower_left,
+    /// Indicates that the fragment shader writes to `frag_depth`,
+    /// replacing the fixed-function depth value.
+    depth_replacing,
     /// Indicates that per-fragment tests may assume that
-    /// any `fragmentDepth` built in-decorated value written by the shader is
+    /// any `frag_depth` built-in-decorated value written by the shader is
     /// greater-than-or-equal to the fragment’s interpolated depth value
-    greater = 14,
+    depth_greater,
     /// Indicates that per-fragment tests may assume that
-    /// any `fragmentDepth` built in-decorated value written by the shader is
+    /// any `frag_depth` built-in-decorated value written by the shader is
     /// less-than-or-equal to the fragment’s interpolated depth value
-    less = 15,
+    depth_less,
     /// Indicates that per-fragment tests may assume that
-    /// any `fragmentDepth` built in-decorated value written by the shader is
+    /// any `frag_depth` built-in-decorated value written by the shader is
     /// the same as the fragment’s interpolated depth value
-    unchanged = 16,
-};
+    depth_unchanged,
+    /// Indicates the workgroup size in the x, y, and z dimensions.
+    local_size: LocalSize,
 
-/// Only valid with the `Fragment` calling convention.
-pub fn depthMode(comptime entry_point: anytype, comptime mode: DepthMode) void {
-    asm volatile (
-        \\OpExecutionMode %entry_point $mode
-        :
-        : [entry_point] "" (entry_point),
-          [mode] "c" (mode),
-    );
-}
+    pub const Tag = enum(u32) {
+        origin_upper_left = 7,
+        origin_lower_left = 8,
+        depth_replacing = 12,
+        depth_greater = 14,
+        depth_less = 15,
+        depth_unchanged = 16,
+        local_size = 17,
+    };
 
-/// Indicates the workgroup size in the `x`, `y`, and `z` dimensions.
-/// Only valid with the `GLCompute` or `Kernel` calling conventions.
-pub fn workgroupSize(comptime entry_point: anytype, comptime size: @Vector(3, u32)) void {
-    asm volatile (
-        \\OpExecutionMode %entry_point LocalSize %x %y %z
-        :
-        : [entry_point] "" (entry_point),
-          [x] "c" (size[0]),
-          [y] "c" (size[1]),
-          [z] "c" (size[2]),
-    );
-}
+    pub const LocalSize = struct { x: u32, y: u32, z: u32 };
+};
 
-/// A hint to the client, which indicates the workgroup size in the `x`, `y`, and `z` dimensions.
-/// Only valid with the `GLCompute` or `Kernel` calling conventions.
-pub fn workgroupSizeHint(comptime entry_point: anytype, comptime size: @Vector(3, u32)) void {
-    asm volatile (
-        \\OpExecutionMode %entry_point LocalSizeHint %x %y %z
-        :
-        : [entry_point] "" (entry_point),
-          [x] "c" (size[0]),
-          [y] "c" (size[1]),
-          [z] "c" (size[2]),
-    );
+/// Declare the mode the entry point executes in.
+pub fn executionMode(comptime entry_point: anytype, comptime mode: ExecutionMode) void {
+    const cc = @typeInfo(@TypeOf(entry_point)).@"fn".calling_convention;
+    switch (mode) {
+        .origin_upper_left,
+        .origin_lower_left,
+        .depth_replacing,
+        .depth_greater,
+        .depth_less,
+        .depth_unchanged,
+        => {
+            if (cc != .spirv_fragment) {
+                @compileError(
+                    \\invalid execution mode '
+                ++ @tagName(mode) ++
+                    \\' for function with '
+                ++ @tagName(cc) ++
+                    \\' calling convention
+                );
+            }
+            asm volatile (
+                \\OpExecutionMode %entry_point $mode
+                :
+                : [entry_point] "" (entry_point),
+                  [mode] "c" (@intFromEnum(mode)),
+            );
+        },
+        .local_size => |size| {
+            if (cc != .spirv_kernel) {
+                @compileError(
+                    \\invalid execution mode 'local_size' for function with '
+                ++ @tagName(cc) ++
+                    \\' calling convention
+                );
+            }
+            asm volatile (
+                \\OpExecutionMode %entry_point LocalSize $x $y $z
+                :
+                : [entry_point] "" (entry_point),
+                  [x] "c" (size.x),
+                  [y] "c" (size.y),
+                  [z] "c" (size.z),
+            );
+        },
+    }
 }
src/codegen/spirv/Assembler.zig
@@ -296,12 +296,26 @@ fn processInstruction(self: *Assembler) !void {
             };
             break :blk .{ .value = try self.spv.importInstructionSet(set_tag) };
         },
+        .OpExecutionMode, .OpExecutionModeId => {
+            assert(try self.processGenericInstruction() == null);
+            const entry_point_id = try self.resolveRefId(self.inst.operands.items[0].ref_id);
+            const exec_mode: spec.ExecutionMode = @enumFromInt(self.inst.operands.items[1].value);
+            const gop = try self.spv.entry_points.getOrPut(self.gpa, entry_point_id);
+            if (!gop.found_existing) {
+                gop.value_ptr.* = .{};
+            } else if (gop.value_ptr.exec_mode != null) {
+                return self.fail(
+                    self.currentToken().start,
+                    "cannot set execution mode more than once for the same entry point",
+                    .{},
+                );
+            }
+            gop.value_ptr.exec_mode = exec_mode;
+            return;
+        },
         else => switch (self.inst.opcode.class()) {
             .TypeDeclaration => try self.processTypeInstruction(),
-            else => if (try self.processGenericInstruction()) |result|
-                result
-            else
-                return,
+            else => (try self.processGenericInstruction()) orelse return,
         },
     };
 
src/codegen/spirv/Module.zig
@@ -92,11 +92,12 @@ pub const Decl = struct {
 /// This models a kernel entry point.
 pub const EntryPoint = struct {
     /// The declaration that should be exported.
-    decl_index: Decl.Index,
+    decl_index: ?Decl.Index = null,
     /// The name of the kernel to be exported.
-    name: []const u8,
+    name: ?[]const u8 = null,
     /// Calling Convention
-    execution_model: spec.ExecutionModel,
+    exec_model: ?spec.ExecutionModel = null,
+    exec_mode: ?spec.ExecutionMode = null,
 };
 
 /// A general-purpose allocator which may be used to allocate resources for this module
@@ -184,7 +185,7 @@ decls: std.ArrayListUnmanaged(Decl) = .empty,
 decl_deps: std.ArrayListUnmanaged(Decl.Index) = .empty,
 
 /// The list of entry points that should be exported from this module.
-entry_points: std.ArrayListUnmanaged(EntryPoint) = .empty,
+entry_points: std.AutoArrayHashMapUnmanaged(IdRef, EntryPoint) = .empty,
 
 pub fn init(gpa: Allocator, target: std.Target) Module {
     const version_minor: u8 = blk: {
@@ -304,19 +305,30 @@ fn entryPoints(self: *Module) !Section {
     var seen = try std.DynamicBitSetUnmanaged.initEmpty(self.gpa, self.decls.items.len);
     defer seen.deinit(self.gpa);
 
-    for (self.entry_points.items) |entry_point| {
+    for (self.entry_points.keys(), self.entry_points.values()) |entry_point_id, entry_point| {
         interface.items.len = 0;
         seen.setRangeValue(.{ .start = 0, .end = self.decls.items.len }, false);
 
-        try self.addEntryPointDeps(entry_point.decl_index, &seen, &interface);
-
-        const entry_point_id = self.declPtr(entry_point.decl_index).result_id;
+        try self.addEntryPointDeps(entry_point.decl_index.?, &seen, &interface);
         try entry_points.emit(self.gpa, .OpEntryPoint, .{
-            .execution_model = entry_point.execution_model,
+            .execution_model = entry_point.exec_model.?,
             .entry_point = entry_point_id,
-            .name = entry_point.name,
+            .name = entry_point.name.?,
             .interface = interface.items,
         });
+
+        if (entry_point.exec_mode == null and entry_point.exec_model == .Fragment) {
+            switch (self.target.os.tag) {
+                .vulkan, .opengl => |tag| {
+                    try self.sections.execution_modes.emit(self.gpa, .OpExecutionMode, .{
+                        .entry_point = entry_point_id,
+                        .mode = if (tag == .vulkan) .OriginUpperLeft else .OriginLowerLeft,
+                    });
+                },
+                .opencl => {},
+                else => unreachable,
+            }
+        }
     }
 
     return entry_points;
@@ -749,13 +761,15 @@ pub fn declareEntryPoint(
     self: *Module,
     decl_index: Decl.Index,
     name: []const u8,
-    execution_model: spec.ExecutionModel,
+    exec_model: spec.ExecutionModel,
+    exec_mode: ?spec.ExecutionMode,
 ) !void {
-    try self.entry_points.append(self.gpa, .{
-        .decl_index = decl_index,
-        .name = try self.arena.allocator().dupe(u8, name),
-        .execution_model = execution_model,
-    });
+    const gop = try self.entry_points.getOrPut(self.gpa, self.declPtr(decl_index).result_id);
+    gop.value_ptr.decl_index = decl_index;
+    gop.value_ptr.name = try self.arena.allocator().dupe(u8, name);
+    gop.value_ptr.exec_model = exec_model;
+    // Might have been set by the assembler.
+    if (!gop.found_existing) gop.value_ptr.exec_mode = exec_mode;
 }
 
 pub fn debugName(self: *Module, target: IdResult, name: []const u8) !void {
src/codegen/spirv.zig
@@ -2870,7 +2870,7 @@ const NavGen = struct {
         };
 
         try self.spv.declareDeclDeps(spv_decl_index, decl_deps.items);
-        try self.spv.declareEntryPoint(spv_decl_index, test_name, execution_mode);
+        try self.spv.declareEntryPoint(spv_decl_index, test_name, execution_mode, null);
     }
 
     fn genNav(self: *NavGen, do_codegen: bool) !void {
@@ -2976,10 +2976,6 @@ const NavGen = struct {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .Position } });
                 } else if (nav.fqn.eqlSlice("point_size", ip)) {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .PointSize } });
-                } else if (nav.fqn.eqlSlice("vertex_id", ip)) {
-                    try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .VertexId } });
-                } else if (nav.fqn.eqlSlice("instance_id", ip)) {
-                    try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .InstanceId } });
                 } else if (nav.fqn.eqlSlice("invocation_id", ip)) {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .InvocationId } });
                 } else if (nav.fqn.eqlSlice("frag_coord", ip)) {
@@ -2990,8 +2986,6 @@ const NavGen = struct {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .FrontFacing } });
                 } else if (nav.fqn.eqlSlice("sample_mask", ip)) {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .SampleMask } });
-                } else if (nav.fqn.eqlSlice("sample_mask", ip)) {
-                    try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .SampleMask } });
                 } else if (nav.fqn.eqlSlice("frag_depth", ip)) {
                     try self.spv.decorate(result_id, .{ .BuiltIn = .{ .built_in = .FragDepth } });
                 } else if (nav.fqn.eqlSlice("num_workgroups", ip)) {
src/link/SpirV.zig
@@ -162,7 +162,7 @@ pub fn updateExports(
     if (ip.isFunctionType(nav_ty)) {
         const spv_decl_index = try self.object.resolveNav(zcu, nav_index);
         const cc = Type.fromInterned(nav_ty).fnCallingConvention(zcu);
-        const execution_model: spec.ExecutionModel = switch (target.os.tag) {
+        const exec_model: spec.ExecutionModel = switch (target.os.tag) {
             .vulkan, .opengl => switch (cc) {
                 .spirv_vertex => .Vertex,
                 .spirv_fragment => .Fragment,
@@ -185,7 +185,8 @@ pub fn updateExports(
             try self.object.spv.declareEntryPoint(
                 spv_decl_index,
                 exp.opts.name.toSlice(ip),
-                execution_model,
+                exec_model,
+                null,
             );
         }
     }