Commit 12350f53bf

Robin Voetter <robin@voetter.nl>
2024-03-30 18:30:28
spirv: clz, ctz for opencl
These instructions seem to be common in compiler_rt.
1 parent f5ab3c9
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2332,6 +2332,9 @@ const DeclGen = struct {
 
             .mul_add => try self.airMulAdd(inst),
 
+            .ctz => try self.airClzCtz(inst, .ctz),
+            .clz => try self.airClzCtz(inst, .clz),
+
             .splat => try self.airSplat(inst),
             .reduce, .reduce_optimized => try self.airReduce(inst),
             .shuffle => try self.airShuffle(inst),
@@ -3029,6 +3032,83 @@ const DeclGen = struct {
         return try wip.finalize();
     }
 
+    /// Lowers the AIR `clz` and `ctz` instructions; `op` selects which one.
+    /// The count is computed element-wise with the matching OpenCL.std extended
+    /// instruction and converted to the result type when the two types differ.
+    /// Returns null if the instruction's result is unused. Composite integers and
+    /// non-OpenCL targets are not implemented yet and hit `unreachable`.
+    fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const mod = self.module;
+        const target = self.getTarget();
+        const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
+        const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(ty_op.operand);
+        const operand = try self.resolve(ty_op.operand);
+
+        const info = self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => unreachable, // TODO
+            // NOTE(review): for .strange_integer the count looks like it is taken over the width
+            // of the type elem_ty resolves to; confirm that matches the Zig bit width.
+            .integer, .strange_integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip = try self.elementWise(result_ty, false);
+        defer wip.deinit();
+
+        const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty;
+        const elem_ty_ref = try self.resolveType(elem_ty, .direct);
+        const elem_ty_id = self.typeId(elem_ty_ref);
+
+        for (wip.results, 0..) |*result_id, i| {
+            const elem = try wip.elementAt(operand_ty, operand, i);
+
+            switch (target.os.tag) {
+                .opencl => {
+                    const set = try self.spv.importInstructionSet(.@"OpenCL.std");
+                    const ext_inst: u32 = switch (op) {
+                        .clz => 151, // clz
+                        .ctz => 152, // ctz
+                    };
+
+                    // Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty.
+                    // result_ty is always large enough to hold the result, so we might have to down
+                    // cast it.
+                    const tmp = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
+                        .id_result_type = elem_ty_id,
+                        .id_result = tmp,
+                        .set = set,
+                        .instruction = .{ .inst = ext_inst },
+                        .id_ref_4 = &.{elem},
+                    });
+
+                    if (wip.ty_id == elem_ty_id) {
+                        result_id.* = tmp;
+                        continue;
+                    }
+
+                    // @clz/@ctz always yield an unsigned result type, whatever the operand's
+                    // signedness, and the count is never negative: converting with OpUConvert
+                    // is therefore correct for signed operands as well.
+                    result_id.* = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                        .id_result_type = wip.ty_id,
+                        .id_result = result_id.*,
+                        .unsigned_value = tmp,
+                    });
+                },
+                .vulkan => unreachable, // TODO
+                else => unreachable,
+            }
+        }
+
+        return try wip.finalize();
+    }
+
     fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
test/behavior/math.zig
@@ -65,7 +65,6 @@ test "@clz" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try testClz();
     try comptime testClz();
@@ -148,7 +147,6 @@ test "@ctz" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try testCtz();
     try comptime testCtz();
@@ -1752,7 +1750,6 @@ test "@clz works on both vector and scalar inputs" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var x: u32 = 0x1;
     _ = &x;