Commit 3357c59ceb

Robin Voetter <robin@voetter.nl>
2023-03-18 15:59:56
new builtins: @workItemId, @workGroupId, @workGroupSize
* @workItemId returns the index of the work item in a work group for a dimension. * @workGroupId returns the index of the work group in the kernel dispatch for a dimension. * @workGroupSize returns the size of the work group for a dimension. These builtins are mainly useful for GPU backends. They are currently only implemented for the AMDGCN LLVM backend.
1 parent 83051b0
doc/langref.html.in
@@ -9578,6 +9578,28 @@ fn foo(comptime T: type, ptr: *T) T {
       Remove {#syntax#}volatile{#endsyntax#} qualifier from a pointer.
       </p>
       {#header_close#}
+
+      {#header_open|@workGroupId#}
+      <pre>{#syntax#}@workGroupId(comptime dimension: u32) u32{#endsyntax#}</pre>
+      <p>
+      Returns the index of the work group in the current kernel invocation in dimension {#syntax#}dimension{#endsyntax#}.
+      </p>
+      {#header_close#}
+
+      {#header_open|@workGroupSize#}
+      <pre>{#syntax#}@workGroupSize(comptime dimension: u32) u32{#endsyntax#}</pre>
+      <p>
+      Returns the number of work items that a work group has in dimension {#syntax#}dimension{#endsyntax#}.
+      </p>
+      {#header_close#}
+
+      {#header_open|@workItemId#}
+      <pre>{#syntax#}@workItemId(comptime dimension: u32) u32{#endsyntax#}</pre>
+      <p>
+      Returns the index of the work item in the work group in dimension {#syntax#}dimension{#endsyntax#}. This function returns values between {#syntax#}0{#endsyntax#} (inclusive) and {#syntax#}@workGroupSize(dimension){#endsyntax#} (exclusive).
+      </p>
+      {#header_close#}
+
       {#header_close#}
 
       {#header_open|Build Mode#}
src/arch/aarch64/CodeGen.zig
@@ -890,6 +890,10 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
+
+            .work_item_id => unreachable,
+            .work_group_size => unreachable,
+            .work_group_id => unreachable,
             // zig fmt: on
         }
 
src/arch/arm/CodeGen.zig
@@ -874,6 +874,10 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
+
+            .work_item_id => unreachable,
+            .work_group_size => unreachable,
+            .work_group_id => unreachable,
             // zig fmt: on
         }
 
src/arch/riscv64/CodeGen.zig
@@ -704,6 +704,10 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
+
+            .work_item_id => unreachable,
+            .work_group_size => unreachable,
+            .work_group_id => unreachable,
             // zig fmt: on
         }
         if (std.debug.runtime_safety) {
src/arch/sparc64/CodeGen.zig
@@ -720,6 +720,10 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
+
+            .work_item_id => unreachable,
+            .work_group_size => unreachable,
+            .work_group_id => unreachable,
             // zig fmt: on
         }
 
src/arch/wasm/CodeGen.zig
@@ -1997,6 +1997,11 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
         .reduce_optimized,
         .float_to_int_optimized,
         => return func.fail("TODO implement optimized float mode", .{}),
+
+        .work_item_id,
+        .work_group_size,
+        .work_group_id,
+        => unreachable,
     };
 }
 
src/arch/x86_64/CodeGen.zig
@@ -1132,6 +1132,10 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
             .wasm_memory_size => unreachable,
             .wasm_memory_grow => unreachable,
+
+            .work_item_id => unreachable,
+            .work_group_size => unreachable,
+            .work_group_id => unreachable,
             // zig fmt: on
         }
 
src/codegen/c.zig
@@ -2995,6 +2995,11 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail,
             .c_va_arg => try airCVaArg(f, inst),
             .c_va_end => try airCVaEnd(f, inst),
             .c_va_copy => try airCVaCopy(f, inst),
+
+            .work_item_id,
+            .work_group_size,
+            .work_group_id,
+            => unreachable,
             // zig fmt: on
         };
         if (result_value == .new_local) {
src/codegen/llvm.zig
@@ -4745,6 +4745,10 @@ pub const FuncGen = struct {
                 .c_va_copy => try self.airCVaCopy(inst),
                 .c_va_end => try self.airCVaEnd(inst),
                 .c_va_start => try self.airCVaStart(inst),
+
+                .work_item_id => try self.airWorkItemId(inst),
+                .work_group_size => try self.airWorkGroupSize(inst),
+                .work_group_id => try self.airWorkGroupId(inst),
                 // zig fmt: on
             };
             if (opt_value) |val| {
@@ -9567,6 +9571,74 @@ pub const FuncGen = struct {
         return self.builder.buildAddrSpaceCast(operand, llvm_dest_ty, "");
     }
 
+    fn amdgcnWorkIntrinsic(self: *FuncGen, dimension: u32, default: u32, comptime basename: []const u8) !?*llvm.Value {
+        const llvm_u32 = self.context.intType(32);
+
+        const llvm_fn_name = switch (dimension) {
+            0 => basename ++ ".x",
+            1 => basename ++ ".y",
+            2 => basename ++ ".z",
+            else => return llvm_u32.constInt(default, .False),
+        };
+
+        const args: [0]*llvm.Value = .{};
+        const llvm_fn = self.getIntrinsic(llvm_fn_name, &.{});
+        return self.builder.buildCall(llvm_fn.globalGetValueType(), llvm_fn, &args, args.len, .Fast, .Auto, "");
+    }
+
+    fn airWorkItemId(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const target = self.dg.module.getTarget();
+        assert(target.cpu.arch == .amdgcn); // TODO is to port this function to other GPU architectures
+
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const dimension = pl_op.payload;
+        return self.amdgcnWorkIntrinsic(dimension, 0, "llvm.amdgcn.workitem.id");
+    }
+
+    fn airWorkGroupSize(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const target = self.dg.module.getTarget();
+        assert(target.cpu.arch == .amdgcn); // TODO is to port this function to other GPU architectures
+
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const dimension = pl_op.payload;
+        const llvm_u32 = self.context.intType(32);
+        if (dimension >= 3) {
+            return llvm_u32.constInt(1, .False);
+        }
+
+        // Fetch the dispatch pointer, which points to this structure:
+        // https://github.com/RadeonOpenCompute/ROCR-Runtime/blob/adae6c61e10d371f7cbc3d0e94ae2c070cab18a4/src/inc/hsa.h#L2913
+        const llvm_fn = self.getIntrinsic("llvm.amdgcn.dispatch.ptr", &.{});
+        const args: [0]*llvm.Value = .{};
+        const dispatch_ptr = self.builder.buildCall(llvm_fn.globalGetValueType(), llvm_fn, &args, args.len, .Fast, .Auto, "");
+        dispatch_ptr.setAlignment(4);
+
+        // Load the work_group_* member from the struct as u16.
+        // Just treat the dispatch pointer as an array of u16 to keep things simple.
+        const offset = 2 + dimension;
+        const index = [_]*llvm.Value{llvm_u32.constInt(offset, .False)};
+        const llvm_u16 = self.context.intType(16);
+        const workgroup_size_ptr = self.builder.buildInBoundsGEP(llvm_u16, dispatch_ptr, &index, index.len, "");
+        const workgroup_size = self.builder.buildLoad(llvm_u16, workgroup_size_ptr, "");
+        workgroup_size.setAlignment(2);
+        return workgroup_size;
+    }
+
+    fn airWorkGroupId(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const target = self.dg.module.getTarget();
+        assert(target.cpu.arch == .amdgcn); // TODO is to port this function to other GPU architectures
+
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const dimension = pl_op.payload;
+        return self.amdgcnWorkIntrinsic(dimension, 0, "llvm.amdgcn.workgroup.id");
+    }
+
     fn getErrorNameTable(self: *FuncGen) !*llvm.Value {
         if (self.dg.object.error_name_table) |table| {
             return table;
src/Air.zig
@@ -761,6 +761,22 @@ pub const Inst = struct {
         /// Uses the `ty` field.
         c_va_start,
 
+        /// Implements @workItemId builtin.
+        /// Result type is always `u32`
+        /// Uses the `pl_op` field, payload is the dimension to get the work item id for.
+        /// Operand is unused and set to Ref.none
+        work_item_id,
+        /// Implements @workGroupSize builtin.
+        /// Result type is always `u32`
+        /// Uses the `pl_op` field, payload is the dimension to get the work group size for.
+        /// Operand is unused and set to Ref.none
+        work_group_size,
+        /// Implements @workGroupId builtin.
+        /// Result type is always `u32`
+        /// Uses the `pl_op` field, payload is the dimension to get the work group id for.
+        /// Operand is unused and set to Ref.none
+        work_group_id,
+
         pub fn fromCmpOp(op: std.math.CompareOperator, optimized: bool) Tag {
             switch (op) {
                 .lt => return if (optimized) .cmp_lt_optimized else .cmp_lt,
@@ -1267,6 +1283,11 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
             const err_union_ty = air.typeOf(datas[inst].pl_op.operand);
             return err_union_ty.errorUnionPayload();
         },
+
+        .work_item_id,
+        .work_group_size,
+        .work_group_id,
+        => return Type.u32,
     }
 }
 
src/AstGen.zig
@@ -8549,6 +8549,40 @@ fn builtinCall(
             }
             return rvalue(gz, ri, try gz.addNodeExtended(.c_va_start, node), node);
         },
+
+        .work_item_id => {
+            if (astgen.fn_block == null) {
+                return astgen.failNode(node, "'@workItemId' outside function scope", .{});
+            }
+            const operand = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .u32_type } }, params[0]);
+            const result = try gz.addExtendedPayload(.work_item_id, Zir.Inst.UnNode{
+                .node = gz.nodeIndexToRelative(node),
+                .operand = operand,
+            });
+            return rvalue(gz, ri, result, node);
+        },
+        .work_group_size => {
+            if (astgen.fn_block == null) {
+                return astgen.failNode(node, "'@workGroupSize' outside function scope", .{});
+            }
+            const operand = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .u32_type } }, params[0]);
+            const result = try gz.addExtendedPayload(.work_group_size, Zir.Inst.UnNode{
+                .node = gz.nodeIndexToRelative(node),
+                .operand = operand,
+            });
+            return rvalue(gz, ri, result, node);
+        },
+        .work_group_id => {
+            if (astgen.fn_block == null) {
+                return astgen.failNode(node, "'@workGroupId' outside function scope", .{});
+            }
+            const operand = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .u32_type } }, params[0]);
+            const result = try gz.addExtendedPayload(.work_group_id, Zir.Inst.UnNode{
+                .node = gz.nodeIndexToRelative(node),
+                .operand = operand,
+            });
+            return rvalue(gz, ri, result, node);
+        },
     }
 }
 
src/BuiltinFn.zig
@@ -118,6 +118,9 @@ pub const Tag = enum {
     union_init,
     Vector,
     volatile_cast,
+    work_item_id,
+    work_group_size,
+    work_group_id,
 };
 
 pub const MemLocRequirement = enum {
@@ -980,5 +983,25 @@ pub const list = list: {
                 .param_count = 1,
             },
         },
+        .{
+            "@workItemId", .{
+                .tag = .work_item_id,
+                .param_count = 1,
+            },
+        },
+        .{
+            "@workGroupSize",
+            .{
+                .tag = .work_group_size,
+                .param_count = 1,
+            },
+        },
+        .{
+            "@workGroupId",
+            .{
+                .tag = .work_group_id,
+                .param_count = 1,
+            },
+        },
     });
 };
src/Liveness.zig
@@ -240,6 +240,9 @@ pub fn categorizeOperand(
         .err_return_trace,
         .save_err_return_trace_index,
         .c_va_start,
+        .work_item_id,
+        .work_group_size,
+        .work_group_id,
         => return .none,
 
         .fence => return .write,
@@ -864,6 +867,9 @@ fn analyzeInst(
         .err_return_trace,
         .save_err_return_trace_index,
         .c_va_start,
+        .work_item_id,
+        .work_group_size,
+        .work_group_id,
         => return trackOperands(a, new_set, inst, main_tomb, .{ .none, .none, .none }),
 
         .not,
src/print_air.zig
@@ -328,6 +328,11 @@ const Writer = struct {
             .vector_store_elem => try w.writeVectorStoreElem(s, inst),
 
             .dbg_block_begin, .dbg_block_end => {},
+
+            .work_item_id,
+            .work_group_size,
+            .work_group_id,
+            => try w.writeWorkDimension(s, inst),
         }
         try s.writeAll(")\n");
     }
@@ -869,6 +874,11 @@ const Writer = struct {
         try w.writeOperand(s, inst, 0, pl_op.operand);
     }
 
+    fn writeWorkDimension(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void {
+        const pl_op = w.air.instructions.items(.data)[inst].pl_op;
+        try s.print("{d}", .{pl_op.payload});
+    }
+
     fn writeOperand(
         w: *Writer,
         s: anytype,
src/print_zir.zig
@@ -512,6 +512,9 @@ const Writer = struct {
             .c_va_end,
             .const_cast,
             .volatile_cast,
+            .work_item_id,
+            .work_group_size,
+            .work_group_id,
             => {
                 const inst_data = self.code.extraData(Zir.Inst.UnNode, extended.operand).data;
                 const src = LazySrcLoc.nodeOffset(inst_data.node);
src/Sema.zig
@@ -1164,6 +1164,9 @@ fn analyzeBodyInner(
                     .c_va_start            => try sema.zirCVaStart(          block, extended),
                     .const_cast,           => try sema.zirConstCast(         block, extended),
                     .volatile_cast,        => try sema.zirVolatileCast(      block, extended),
+                    .work_item_id          => try sema.zirWorkItem(          block, extended, extended.opcode),
+                    .work_group_size       => try sema.zirWorkItem(          block, extended, extended.opcode),
+                    .work_group_id         => try sema.zirWorkItem(          block, extended, extended.opcode),
                     // zig fmt: on
 
                     .fence => {
@@ -22437,6 +22440,42 @@ fn zirBuiltinExtern(
     return sema.addConstant(ty, ref);
 }
 
+fn zirWorkItem(
+    sema: *Sema,
+    block: *Block,
+    extended: Zir.Inst.Extended.InstData,
+    zir_tag: Zir.Inst.Extended,
+) CompileError!Air.Inst.Ref {
+    const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
+    const dimension_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
+    const builtin_src = LazySrcLoc.nodeOffset(extra.node);
+    const target = sema.mod.getTarget();
+
+    switch (target.cpu.arch) {
+        // TODO: Allow for other GPU targets.
+        .amdgcn => {},
+        else => {
+            return sema.fail(block, builtin_src, "builtin only available on GPU targets; targeted architecture is {s}", .{@tagName(target.cpu.arch)});
+        },
+    }
+
+    const dimension = @intCast(u32, try sema.resolveInt(block, dimension_src, extra.operand, Type.u32, "dimension must be comptime-known"));
+    try sema.requireRuntimeBlock(block, builtin_src, null);
+
+    return block.addInst(.{
+        .tag = switch (zir_tag) {
+            .work_item_id => .work_item_id,
+            .work_group_size => .work_group_size,
+            .work_group_id => .work_group_id,
+            else => unreachable,
+        },
+        .data = .{ .pl_op = .{
+            .operand = .none,
+            .payload = dimension,
+        } },
+    });
+}
+
 fn requireRuntimeBlock(sema: *Sema, block: *Block, src: LazySrcLoc, runtime_src: ?LazySrcLoc) !void {
     if (block.is_comptime) {
         const msg = msg: {
src/Zir.zig
@@ -2032,6 +2032,15 @@ pub const Inst = struct {
         /// Implements the `@volatileCast` builtin.
         /// `operand` is payload index to `UnNode`.
         volatile_cast,
+        /// Implements the `@workItemId` builtin.
+        /// `operand` is payload index to `UnNode`.
+        work_item_id,
+        /// Implements the `@workGroupSize` builtin.
+        /// `operand` is payload index to `UnNode`.
+        work_group_size,
+        /// Implements the `@workGroupId` builtin.
+        /// `operand` is payload index to `UnNode`.
+        work_group_id,
 
         pub const InstData = struct {
             opcode: Extended,