Commit 50a771a11e

Robin Voetter <robin@voetter.nl>
2024-06-15 14:17:07
spirv: add support for workItemId, workGroupId, workGroupSize
1 parent 7829be6
Changed files (3)
src
src/codegen/spirv/Module.zig
@@ -8,7 +8,6 @@
 const Module = @This();
 
 const std = @import("std");
-const builtin = @import("builtin");
 const Allocator = std.mem.Allocator;
 const assert = std.debug.assert;
 
@@ -150,6 +149,8 @@ strings: std.StringArrayHashMapUnmanaged(IdRef) = .{},
 /// this is an ad-hoc structure to cache types where required.
 /// According to the SPIR-V specification, section 2.8, this includes all non-aggregate
 /// non-pointer types.
+/// Additionally, this is used for other values which can be cached, for example,
+/// built-in variables.
 cache: struct {
     bool_type: ?IdRef = null,
     void_type: ?IdRef = null,
@@ -158,6 +159,8 @@ cache: struct {
     // This cache is required so that @Vector(X, u1) in direct representation has the
     // same ID as @Vector(X, bool) in indirect representation.
     vector_types: std.AutoHashMapUnmanaged(struct { IdRef, u32 }, IdRef) = .{},
+
+    builtins: std.AutoHashMapUnmanaged(struct { IdRef, spec.BuiltIn }, Decl.Index) = .{},
 } = .{},
 
 /// Set of Decls, referred to by Decl.Index.
@@ -198,6 +201,7 @@ pub fn deinit(self: *Module) void {
     self.cache.int_types.deinit(self.gpa);
     self.cache.float_types.deinit(self.gpa);
     self.cache.vector_types.deinit(self.gpa);
+    self.cache.builtins.deinit(self.gpa);
 
     self.decls.deinit(self.gpa);
     self.decl_deps.deinit(self.gpa);
@@ -491,6 +495,25 @@ pub fn vectorType(self: *Module, len: u32, child_id: IdRef) !IdRef {
     return entry.value_ptr.*;
 }
 
+/// Return a pointer to a builtin variable. `result_ty_id` must be a **pointer**
+/// with storage class `.Input`.
+pub fn builtin(self: *Module, result_ty_id: IdRef, spirv_builtin: spec.BuiltIn) !Decl.Index {
+    const entry = try self.cache.builtins.getOrPut(self.gpa, .{ result_ty_id, spirv_builtin });
+    if (!entry.found_existing) {
+        const decl_index = try self.allocDecl(.global);
+        const result_id = self.declPtr(decl_index).result_id;
+        entry.value_ptr.* = decl_index;
+        try self.sections.types_globals_constants.emit(self.gpa, .OpVariable, .{
+            .id_result_type = result_ty_id,
+            .id_result = result_id,
+            .storage_class = .Input,
+        });
+        try self.decorate(result_id, .{ .BuiltIn = .{ .built_in = spirv_builtin } });
+        try self.declareDeclDeps(decl_index, &.{});
+    }
+    return entry.value_ptr.*;
+}
+
 pub fn constUndef(self: *Module, ty_id: IdRef) !IdRef {
     const result_id = self.allocId();
     try self.sections.types_globals_constants.emit(self.gpa, .OpUndef, .{
src/codegen/spirv.zig
@@ -3376,6 +3376,11 @@ const DeclGen = struct {
             .call_always_tail  => try self.airCall(inst, .always_tail),
             .call_never_tail   => try self.airCall(inst, .never_tail),
             .call_never_inline => try self.airCall(inst, .never_inline),
+
+            .work_item_id => try self.airWorkItemId(inst),
+            .work_group_size => try self.airWorkGroupSize(inst),
+            .work_group_id => try self.airWorkGroupId(inst),
+
             // zig fmt: on
 
             else => |tag| return self.todo("implement AIR tag {s}", .{@tagName(tag)}),
@@ -6533,6 +6538,56 @@ const DeclGen = struct {
         return result_id;
     }
 
+    fn builtin3D(self: *DeclGen, result_ty: Type, builtin: spec.BuiltIn, dimension: u32, out_of_range_value: anytype) !IdRef {
+        const mod = self.module;
+        if (dimension >= 3) {
+            return try self.constInt(result_ty, out_of_range_value, .direct);
+        }
+        const vec_ty = try mod.vectorType(.{
+            .len = 3,
+            .child = result_ty.toIntern(),
+        });
+        const ptr_ty_id = try self.ptrType(vec_ty, .Input);
+        const spv_decl_index = try self.spv.builtin(ptr_ty_id, builtin);
+        try self.func.decl_deps.put(self.spv.gpa, spv_decl_index, {});
+        const ptr = self.spv.declPtr(spv_decl_index).result_id;
+        const vec = try self.load(vec_ty, ptr, .{});
+        return try self.extractVectorComponent(result_ty, vec, dimension);
+    }
+
+    fn airWorkItemId(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+        const dimension = pl_op.payload;
+        // TODO: Should we make these builtins return usize?
+        const result_id = try self.builtin3D(Type.u64, .LocalInvocationId, dimension, 0);
+        const tmp = Temporary.init(Type.u64, result_id);
+        const result = try self.buildIntConvert(Type.u32, tmp);
+        return try result.materialize(self);
+    }
+
+    fn airWorkGroupSize(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+        const dimension = pl_op.payload;
+        // TODO: Should we make these builtins return usize?
+        const result_id = try self.builtin3D(Type.u64, .WorkgroupSize, dimension, 0);
+        const tmp = Temporary.init(Type.u64, result_id);
+        const result = try self.buildIntConvert(Type.u32, tmp);
+        return try result.materialize(self);
+    }
+
+    fn airWorkGroupId(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+        const dimension = pl_op.payload;
+        // TODO: Should we make these builtins return usize?
+        const result_id = try self.builtin3D(Type.u64, .WorkgroupId, dimension, 0);
+        const tmp = Temporary.init(Type.u64, result_id);
+        const result = try self.buildIntConvert(Type.u32, tmp);
+        return try result.materialize(self);
+    }
+
     fn typeOf(self: *DeclGen, inst: Air.Inst.Ref) Type {
         const mod = self.module;
         return self.air.typeOf(inst, &mod.intern_pool);
src/Sema.zig
@@ -26488,7 +26488,7 @@ fn zirWorkItem(
 
     switch (target.cpu.arch) {
         // TODO: Allow for other GPU targets.
-        .amdgcn => {},
+        .amdgcn, .spirv64, .spirv32 => {},
         else => {
             return sema.fail(block, builtin_src, "builtin only available on GPU targets; targeted architecture is {s}", .{@tagName(target.cpu.arch)});
         },