Commit a567f3871e

Robin Voetter <robin@voetter.nl>
2024-06-04 22:09:15
spirv: improve shuffle codegen
1 parent a3b1ba8
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -4082,25 +4082,72 @@ const DeclGen = struct {
         const b = try self.resolve(extra.b);
         const mask = Value.fromInterned(extra.mask);
 
-        const ty = self.typeOfIndex(inst);
+        // Note: number of components in the result, a, and b may differ.
+        const result_ty = self.typeOfIndex(inst);
+        const a_ty = self.typeOf(extra.a);
+        const b_ty = self.typeOf(extra.b);
+
+        const scalar_ty = result_ty.scalarType(mod);
+        const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
+
+        // If all of the types are SPIR-V vectors, we can use OpVectorShuffle.
+        if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) {
+            // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are
+            // numbered consecutively instead of using negatives.
+
+            const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod));
+            defer self.gpa.free(components);
+
+            const a_len = a_ty.vectorLen(mod);
+
+            for (components, 0..) |*component, i| {
+                const elem = try mask.elemValue(mod, i);
+                if (elem.isUndef(mod)) {
+                    // This is explicitly valid for OpVectorShuffle, it indicates undefined.
+                    component.* = 0xFFFF_FFFF;
+                    continue;
+                }
+
+                const index = elem.toSignedInt(mod);
+                if (index >= 0) {
+                    component.* = @intCast(index);
+                } else {
+                    component.* = @intCast(~index + a_len);
+                }
+            }
 
-        var wip = try self.elementWise(ty, true);
-        defer wip.deinit();
-        for (wip.results, 0..) |*result_id, i| {
+            const result_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{
+                .id_result_type = try self.resolveType(result_ty, .direct),
+                .id_result = result_id,
+                .vector_1 = a,
+                .vector_2 = b,
+                .components = components,
+            });
+            return result_id;
+        }
+
+        // Fall back to manually extracting and inserting components.
+
+        const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod));
+        defer self.gpa.free(components);
+
+        for (components, 0..) |*id, i| {
             const elem = try mask.elemValue(mod, i);
             if (elem.isUndef(mod)) {
-                result_id.* = try self.spv.constUndef(wip.ty_id);
+                id.* = try self.spv.constUndef(scalar_ty_id);
                 continue;
             }
 
             const index = elem.toSignedInt(mod);
             if (index >= 0) {
-                result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
+                id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index));
             } else {
-                result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
+                id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index));
             }
         }
-        return try wip.finalize();
+
+        return try self.constructVector(result_ty, components);
     }
 
     fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
test/behavior/shuffle.zig
@@ -2,6 +2,7 @@ const std = @import("std");
 const builtin = @import("builtin");
 const mem = std.mem;
 const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
 
 test "@shuffle int" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
@@ -49,6 +50,88 @@ test "@shuffle int" {
     try comptime S.doTheTest();
 }
 
+test "@shuffle int strange sizes" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    try comptime testShuffle(2, 2, 2);
+    try testShuffle(2, 2, 2);
+    try comptime testShuffle(4, 4, 4);
+    try testShuffle(4, 4, 4);
+    try comptime testShuffle(7, 4, 4);
+    try testShuffle(7, 4, 4);
+    try comptime testShuffle(8, 6, 4);
+    try testShuffle(8, 6, 4);
+    try comptime testShuffle(2, 7, 5);
+    try testShuffle(2, 7, 5);
+    try comptime testShuffle(13, 16, 12);
+    try testShuffle(13, 16, 12);
+    try comptime testShuffle(19, 3, 17);
+    try testShuffle(19, 3, 17);
+    try comptime testShuffle(1, 10, 1);
+    try testShuffle(1, 10, 1);
+}
+
+fn testShuffle(
+    comptime x_len: comptime_int,
+    comptime a_len: comptime_int,
+    comptime b_len: comptime_int,
+) !void {
+    const T = i32;
+    const XT = @Vector(x_len, T);
+    const AT = @Vector(a_len, T);
+    const BT = @Vector(b_len, T);
+
+    const a_elems = comptime blk: {
+        var elems: [a_len]T = undefined;
+        for (&elems, 0..) |*elem, i| elem.* = @intCast(100 + i);
+        break :blk elems;
+    };
+    var a: AT = a_elems;
+    _ = &a;
+
+    const b_elems = comptime blk: {
+        var elems: [b_len]T = undefined;
+        for (&elems, 0..) |*elem, i| elem.* = @intCast(1000 + i);
+        break :blk elems;
+    };
+    var b: BT = b_elems;
+    _ = &b;
+
+    const mask_seed: []const i32 = &.{ -14, -31, 23, 1, 21, 13, 17, -21, -10, -27, -16, -5, 15, 14, -2, 26, 2, -31, -24, -16 };
+
+    const mask = comptime blk: {
+        var elems: [x_len]i32 = undefined;
+        for (&elems, 0..) |*elem, i| {
+            const mask_val = mask_seed[i];
+            if (mask_val >= 0) {
+                elem.* = @mod(mask_val, a_len);
+            } else {
+                elem.* = @mod(mask_val, -b_len);
+            }
+        }
+
+        break :blk elems;
+    };
+
+    const x: XT = @shuffle(T, a, b, mask);
+
+    const x_elems: [x_len]T = x;
+    for (mask, x_elems) |m, x_elem| {
+        if (m >= 0) {
+            // Element from A
+            try expectEqual(x_elem, a_elems[@intCast(m)]);
+        } else {
+            // Element from B
+            try expectEqual(x_elem, b_elems[@intCast(~m)]);
+        }
+    }
+}
+
 test "@shuffle bool 1" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO