Commit 631d1b63a8

Robin Voetter <robin@voetter.nl>
2024-01-21 20:38:56
spirv: fix shuffle properly
1 parent 9641d2e
Changed files (5)
src/codegen/spirv.zig
@@ -2876,37 +2876,31 @@ const DeclGen = struct {
     fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const mod = self.module;
         if (self.liveness.isUnused(inst)) return null;
-        const ty = self.typeOfIndex(inst);
         const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
         const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data;
         const a = try self.resolve(extra.a);
         const b = try self.resolve(extra.b);
         const mask = Value.fromInterned(extra.mask);
-        const mask_len = extra.mask_len;
-        const a_len = self.typeOf(extra.a).vectorLen(mod);
 
-        const result_id = self.spv.allocId();
-        const result_type_id = try self.resolveTypeId(ty);
-        // Similar to LLVM, SPIR-V uses indices larger than the length of the first vector
-        // to index into the second vector.
-        try self.func.body.emitRaw(self.spv.gpa, .OpVectorShuffle, 4 + mask_len);
-        self.func.body.writeOperand(spec.IdResultType, result_type_id);
-        self.func.body.writeOperand(spec.IdResult, result_id);
-        self.func.body.writeOperand(spec.IdRef, a);
-        self.func.body.writeOperand(spec.IdRef, b);
+        const ty = self.typeOfIndex(inst);
 
-        var i: usize = 0;
-        while (i < mask_len) : (i += 1) {
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
             const elem = try mask.elemValue(mod, i);
             if (elem.isUndef(mod)) {
-                self.func.body.writeOperand(spec.LiteralInteger, 0xFFFF_FFFF);
+                result_id.* = try self.spv.constUndef(wip.scalar_ty_ref);
+                continue;
+            }
+
+            const index = elem.toSignedInt(mod);
+            if (index >= 0) {
+                result_id.* = try self.extractField(wip.scalar_ty, a, @intCast(index));
             } else {
-                const int = elem.toSignedInt(mod);
-                const unsigned = if (int >= 0) @as(u32, @intCast(int)) else @as(u32, @intCast(~int + a_len));
-                self.func.body.writeOperand(spec.LiteralInteger, unsigned);
+                result_id.* = try self.extractField(wip.scalar_ty, b, @intCast(~index));
             }
         }
-        return result_id;
+        return try wip.finalize();
     }
 
     fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
test/behavior/abs.zig
@@ -224,7 +224,6 @@ test "@abs unsigned int vectors" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime testAbsUnsignedIntVectors(1);
     try testAbsUnsignedIntVectors(1);
test/behavior/cast.zig
@@ -605,7 +605,6 @@ test "@intCast on vector" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -2508,7 +2507,6 @@ test "@intCast vector of signed integer" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
 
test/behavior/shuffle.zig
@@ -8,7 +8,6 @@ test "@shuffle int" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -54,7 +53,6 @@ test "@shuffle bool 1" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -77,7 +75,6 @@ test "@shuffle bool 2" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     if (builtin.zig_backend == .stage2_llvm) {
         // https://github.com/ziglang/zig/issues/3246
test/behavior/vector.zig
@@ -910,7 +910,6 @@ test "mask parameter of @shuffle is comptime scope" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const __v4hi = @Vector(4, i16);
     var v4_a = __v4hi{ 0, 0, 0, 0 };
@@ -1322,7 +1321,6 @@ test "array operands to shuffle are coerced to vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const mask = [5]i32{ -1, 0, 1, 2, 3 };