Commit fca5f3602d

Ali Cheraghi <alichraghi@proton.me>
2025-05-06 14:55:08
spirv: unroll all vector operations
1 parent f925e13
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -164,8 +164,6 @@ cache: struct {
     void_type: ?IdRef = null,
     int_types: std.AutoHashMapUnmanaged(std.builtin.Type.Int, IdRef) = .empty,
     float_types: std.AutoHashMapUnmanaged(std.builtin.Type.Float, IdRef) = .empty,
-    // This cache is required so that @Vector(X, u1) in direct representation has the
-    // same ID as @Vector(X, bool) in indirect representation.
     vector_types: std.AutoHashMapUnmanaged(struct { IdRef, u32 }, IdRef) = .empty,
     array_types: std.AutoHashMapUnmanaged(struct { IdRef, IdRef }, IdRef) = .empty,
 
src/codegen/spirv.zig
@@ -344,8 +344,7 @@ const NavGen = struct {
 
     /// This structure is used to return information about a type typically used for
     /// arithmetic operations. These types may either be integers, floats, or a vector
-    /// of these. Most scalar operations also work on vectors, so we can easily represent
-    /// those as arithmetic types. If the type is a scalar, 'inner type' refers to the
+    /// of these. If the type is a scalar, 'inner type' refers to the
     /// scalar type. Otherwise, if its a vector, it refers to the vector's element type.
     const ArithmeticTypeInfo = struct {
         /// A classification of the inner type.
@@ -615,41 +614,6 @@ const NavGen = struct {
         return if (self.spv.hasFeature(.int64)) 64 else 32;
     }
 
-    /// Checks whether the type is "composite int", an integer consisting of multiple native integers. These are represented by
-    /// arrays of largestSupportedIntBits().
-    /// Asserts `ty` is an integer.
-    fn isCompositeInt(self: *NavGen, ty: Type) bool {
-        return self.backingIntBits(ty) == null;
-    }
-
-    /// Checks whether the type can be directly translated to SPIR-V vectors
-    fn isSpvVector(self: *NavGen, ty: Type) bool {
-        const zcu = self.pt.zcu;
-        if (ty.zigTypeTag(zcu) != .vector) return false;
-
-        // TODO: This check must be expanded for types that can be represented
-        // as integers (enums / packed structs?) and types that are represented
-        // by multiple SPIR-V values.
-        const scalar_ty = ty.scalarType(zcu);
-        switch (scalar_ty.zigTypeTag(zcu)) {
-            .bool,
-            .int,
-            .float,
-            => {},
-            else => return false,
-        }
-
-        const elem_ty = ty.childType(zcu);
-        const len = ty.vectorLen(zcu);
-
-        if (elem_ty.isNumeric(zcu) or elem_ty.toIntern() == .bool_type) {
-            if (len > 1 and len <= 4) return true;
-            if (self.spv.hasFeature(.vector16)) return (len == 8 or len == 16);
-        }
-
-        return false;
-    }
-
     fn arithmeticTypeInfo(self: *NavGen, ty: Type) ArithmeticTypeInfo {
         const zcu = self.pt.zcu;
         const target = self.spv.target;
@@ -659,14 +623,14 @@ const NavGen = struct {
         }
         const vector_len = if (ty.isVector(zcu)) ty.vectorLen(zcu) else null;
         return switch (scalar_ty.zigTypeTag(zcu)) {
-            .bool => ArithmeticTypeInfo{
+            .bool => .{
                 .bits = 1, // Doesn't matter for this class.
                 .backing_bits = self.backingIntBits(1).?,
                 .vector_len = vector_len,
                 .signedness = .unsigned, // Technically, but doesn't matter for this class.
                 .class = .bool,
             },
-            .float => ArithmeticTypeInfo{
+            .float => .{
                 .bits = scalar_ty.floatBits(target),
                 .backing_bits = scalar_ty.floatBits(target), // TODO: F80?
                 .vector_len = vector_len,
@@ -677,16 +641,16 @@ const NavGen = struct {
                 const int_info = scalar_ty.intInfo(zcu);
                 // TODO: Maybe it's useful to also return this value.
                 const maybe_backing_bits = self.backingIntBits(int_info.bits);
-                break :blk ArithmeticTypeInfo{
+                break :blk .{
                     .bits = int_info.bits,
                     .backing_bits = maybe_backing_bits orelse 0,
                     .vector_len = vector_len,
                     .signedness = int_info.signedness,
                     .class = if (maybe_backing_bits) |backing_bits|
                         if (backing_bits == int_info.bits)
-                            ArithmeticTypeInfo.Class.integer
+                            .integer
                         else
-                            ArithmeticTypeInfo.Class.strange_integer
+                            .strange_integer
                     else
                         .composite_integer,
                 };
@@ -1338,19 +1302,6 @@ const NavGen = struct {
         return self.spv.functionType(return_ty_id, param_ids);
     }
 
-    fn zigScalarOrVectorTypeLike(self: *NavGen, new_ty: Type, base_ty: Type) !Type {
-        const pt = self.pt;
-        const new_scalar_ty = new_ty.scalarType(pt.zcu);
-        if (!base_ty.isVector(pt.zcu)) {
-            return new_scalar_ty;
-        }
-
-        return try pt.vectorType(.{
-            .len = base_ty.vectorLen(pt.zcu),
-            .child = new_scalar_ty.toIntern(),
-        });
-    }
-
     /// Generate a union type. Union types are always generated with the
     /// most aligned field active. If the tag alignment is greater
     /// than that of the payload, a regular union (non-packed, with both tag and
@@ -1632,12 +1583,7 @@ const NavGen = struct {
                 const elem_ty = ty.childType(zcu);
                 const elem_ty_id = try self.resolveType(elem_ty, repr);
                 const len = ty.vectorLen(zcu);
-
-                if (self.isSpvVector(ty)) {
-                    return try self.spv.vectorType(len, elem_ty_id);
-                } else {
-                    return try self.arrayType(len, elem_ty_id);
-                }
+                return self.arrayType(len, elem_ty_id);
             },
             .@"struct" => {
                 const struct_type = switch (ip.indexToKey(ty.toIntern())) {
@@ -2035,69 +1981,32 @@ const NavGen = struct {
     const Vectorization = union(enum) {
         /// This is an operation between scalars.
         scalar,
-        /// This is an operation between SPIR-V vectors.
-        /// Value is number of components.
-        spv_vectorized: u32,
         /// This operation is unrolled into separate operations.
         /// Inputs may still be SPIR-V vectors, for example,
         /// when the operation can't be vectorized in SPIR-V.
         /// Value is number of components.
         unrolled: u32,
 
-        /// Derive a vectorization from a particular type. This usually
-        /// only checks the size, but the source-of-truth is implemented
-        /// by `isSpvVector()`.
+        /// Derive a vectorization from a particular type
         fn fromType(ty: Type, ng: *NavGen) Vectorization {
             const zcu = ng.pt.zcu;
-            if (!ty.isVector(zcu)) {
-                return .scalar;
-            } else if (ng.isSpvVector(ty)) {
-                return .{ .spv_vectorized = ty.vectorLen(zcu) };
-            } else {
-                return .{ .unrolled = ty.vectorLen(zcu) };
-            }
+            if (!ty.isVector(zcu)) return .scalar;
+            return .{ .unrolled = ty.vectorLen(zcu) };
         }
 
         /// Given two vectorization methods, compute a "unification": a fallback
         /// that works for both, according to the following rules:
         /// - Scalars may broadcast
-        /// - SPIR-V vectorized operations may unroll
-        /// - Prefer scalar > SPIR-V vectorized > unrolled
+        /// - SPIR-V vectorized operations will unroll
+        /// - Prefer scalar > unrolled
         fn unify(a: Vectorization, b: Vectorization) Vectorization {
-            if (a == .scalar and b == .scalar) {
-                return .scalar;
-            } else if (a == .spv_vectorized and b == .spv_vectorized) {
-                assert(a.components() == b.components());
-                return .{ .spv_vectorized = a.components() };
-            } else if (a == .unrolled or b == .unrolled) {
-                if (a == .unrolled and b == .unrolled) {
-                    assert(a.components() == b.components());
-                    return .{ .unrolled = a.components() };
-                } else if (a == .unrolled) {
-                    return .{ .unrolled = a.components() };
-                } else if (b == .unrolled) {
-                    return .{ .unrolled = b.components() };
-                } else {
-                    unreachable;
-                }
-            } else {
-                if (a == .spv_vectorized) {
-                    return .{ .spv_vectorized = a.components() };
-                } else if (b == .spv_vectorized) {
-                    return .{ .spv_vectorized = b.components() };
-                } else {
-                    unreachable;
-                }
+            if (a == .scalar and b == .scalar) return .scalar;
+            if (a == .unrolled or b == .unrolled) {
+                if (a == .unrolled and b == .unrolled) assert(a.components() == b.components());
+                if (a == .unrolled) return .{ .unrolled = a.components() };
+                return .{ .unrolled = b.components() };
             }
-        }
-
-        /// Force this vectorization to be unrolled, if its
-        /// an operation involving vectors.
-        fn unroll(self: Vectorization) Vectorization {
-            return switch (self) {
-                .scalar, .unrolled => self,
-                .spv_vectorized => |n| .{ .unrolled = n },
-            };
+            unreachable;
         }
 
         /// Query the number of components that inputs of this operation have.
@@ -2106,35 +2015,10 @@ const NavGen = struct {
         fn components(self: Vectorization) u32 {
             return switch (self) {
                 .scalar => 1,
-                .spv_vectorized => |n| n,
                 .unrolled => |n| n,
             };
         }
 
-        /// Query the number of operations involving this vectorization.
-        /// This is basically the number of components, except that SPIR-V vectorized
-        /// operations only need a single SPIR-V instruction.
-        fn operations(self: Vectorization) u32 {
-            return switch (self) {
-                .scalar, .spv_vectorized => 1,
-                .unrolled => |n| n,
-            };
-        }
-
-        /// Turns `ty` into the result-type of an individual vector operation.
-        /// `ty` may be a scalar or vector, it doesn't matter.
-        fn operationType(self: Vectorization, ng: *NavGen, ty: Type) !Type {
-            const pt = ng.pt;
-            const scalar_ty = ty.scalarType(pt.zcu);
-            return switch (self) {
-                .scalar, .unrolled => scalar_ty,
-                .spv_vectorized => |n| try pt.vectorType(.{
-                    .len = n,
-                    .child = scalar_ty.toIntern(),
-                }),
-            };
-        }
-
         /// Turns `ty` into the result-type of the entire operation.
         /// `ty` may be a scalar or vector, it doesn't matter.
         fn resultType(self: Vectorization, ng: *NavGen, ty: Type) !Type {
@@ -2142,10 +2026,7 @@ const NavGen = struct {
             const scalar_ty = ty.scalarType(pt.zcu);
             return switch (self) {
                 .scalar => scalar_ty,
-                .unrolled, .spv_vectorized => |n| try pt.vectorType(.{
-                    .len = n,
-                    .child = scalar_ty.toIntern(),
-                }),
+                .unrolled => |n| try pt.vectorType(.{ .len = n, .child = scalar_ty.toIntern() }),
             };
         }
 
@@ -2155,51 +2036,19 @@ const NavGen = struct {
         fn prepare(self: Vectorization, ng: *NavGen, tmp: Temporary) !PreparedOperand {
             const pt = ng.pt;
             const is_vector = tmp.ty.isVector(pt.zcu);
-            const is_spv_vector = ng.isSpvVector(tmp.ty);
             const value: PreparedOperand.Value = switch (tmp.value) {
                 .singleton => |id| switch (self) {
                     .scalar => blk: {
                         assert(!is_vector);
                         break :blk .{ .scalar = id };
                     },
-                    .spv_vectorized => blk: {
-                        if (is_vector) {
-                            assert(is_spv_vector);
-                            break :blk .{ .spv_vectorwise = id };
-                        }
-
-                        // Broadcast scalar into vector.
-                        const vector_ty = try pt.vectorType(.{
-                            .len = self.components(),
-                            .child = tmp.ty.toIntern(),
-                        });
-
-                        const vector = try ng.constructCompositeSplat(vector_ty, id);
-                        return .{
-                            .ty = vector_ty,
-                            .value = .{ .spv_vectorwise = vector },
-                        };
-                    },
                     .unrolled => blk: {
-                        if (is_vector) {
-                            break :blk .{ .vector_exploded = try tmp.explode(ng) };
-                        } else {
-                            break :blk .{ .scalar_broadcast = id };
-                        }
+                        if (is_vector) break :blk .{ .vector_exploded = try tmp.explode(ng) };
+                        break :blk .{ .scalar_broadcast = id };
                     },
                 },
                 .exploded_vector => |range| switch (self) {
                     .scalar => unreachable,
-                    .spv_vectorized => |n| blk: {
-                        // We can vectorize this operation, but we have an exploded vector. This can happen
-                        // when a vectorizable operation succeeds a non-vectorizable operation. In this case,
-                        // pack up the IDs into a SPIR-V vector. This path should not be able to be hit with
-                        // a type that cannot do that.
-                        assert(is_spv_vector);
-                        assert(range.len == n);
-                        const vec = try tmp.materialize(ng);
-                        break :blk .{ .spv_vectorwise = vec };
-                    },
                     .unrolled => |n| blk: {
                         assert(range.len == n);
                         break :blk .{ .vector_exploded = range };
@@ -2216,17 +2065,14 @@ const NavGen = struct {
         /// Finalize the results of an operation back into a temporary. `results` is
         /// a list of result-ids of the operation.
         fn finalize(self: Vectorization, ty: Type, results: IdRange) Temporary {
-            assert(self.operations() == results.len);
-            const value: Temporary.Value = switch (self) {
-                .scalar, .spv_vectorized => blk: {
-                    break :blk .{ .singleton = results.at(0) };
-                },
-                .unrolled => blk: {
-                    break :blk .{ .exploded_vector = results };
+            assert(self.components() == results.len);
+            return .{
+                .ty = ty,
+                .value = switch (self) {
+                    .scalar => .{ .singleton = results.at(0) },
+                    .unrolled => .{ .exploded_vector = results },
                 },
             };
-
-            return .{ .ty = ty, .value = value };
         }
 
         /// This struct represents an operand that has gone through some setup, and is
@@ -2242,32 +2088,20 @@ const NavGen = struct {
                 scalar: IdResult,
                 /// A single scalar that is broadcasted in an unrolled operation.
                 scalar_broadcast: IdResult,
-                /// A SPIR-V vector that is used in SPIR-V vectorize operation.
-                spv_vectorwise: IdResult,
                 /// A vector represented by a consecutive list of IDs that is used in an unrolled operation.
                 vector_exploded: IdRange,
             };
 
             /// Query the value at a particular index of the operation. Note that
-            /// the index is *not* the component/lane, but the index of the *operation*. When
-            /// this operation is vectorized, the return value of this function is a SPIR-V vector.
-            /// See also `Vectorization.operations()`.
+            /// the index is *not* the component/lane, but the index of the *operation*.
             fn at(self: PreparedOperand, i: usize) IdResult {
                 switch (self.value) {
                     .scalar => |id| {
                         assert(i == 0);
                         return id;
                     },
-                    .scalar_broadcast => |id| {
-                        return id;
-                    },
-                    .spv_vectorwise => |id| {
-                        assert(i == 0);
-                        return id;
-                    },
-                    .vector_exploded => |range| {
-                        return range.at(i);
-                    },
+                    .scalar_broadcast => |id| return id,
+                    .vector_exploded => |range| return range.at(i),
                 }
             }
         };
@@ -2299,7 +2133,7 @@ const NavGen = struct {
 
     /// This function builds an OpSConvert of OpUConvert depending on the
     /// signedness of the types.
-    fn buildIntConvert(self: *NavGen, dst_ty: Type, src: Temporary) !Temporary {
+    fn buildConvert(self: *NavGen, dst_ty: Type, src: Temporary) !Temporary {
         const zcu = self.pt.zcu;
 
         const dst_ty_id = try self.resolveType(dst_ty.scalarType(zcu), .direct);
@@ -2318,13 +2152,17 @@ const NavGen = struct {
             return src.pun(result_ty);
         }
 
-        const ops = v.operations();
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
 
-        const op_result_ty = try v.operationType(self, dst_ty);
+        const op_result_ty = dst_ty.scalarType(zcu);
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
 
-        const opcode: Opcode = if (dst_ty.isSignedInt(zcu)) .OpSConvert else .OpUConvert;
+        const opcode: Opcode = blk: {
+            if (dst_ty.scalarType(zcu).isAnyFloat()) break :blk .OpFConvert;
+            if (dst_ty.scalarType(zcu).isSignedInt(zcu)) break :blk .OpSConvert;
+            break :blk .OpUConvert;
+        };
 
         const op_src = try v.prepare(self, src);
 
@@ -2339,13 +2177,14 @@ const NavGen = struct {
     }
 
     fn buildFma(self: *NavGen, a: Temporary, b: Temporary, c: Temporary) !Temporary {
+        const zcu = self.pt.zcu;
         const target = self.spv.target;
 
         const v = self.vectorization(.{ a, b, c });
-        const ops = v.operations();
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
 
-        const op_result_ty = try v.operationType(self, a.ty);
+        const op_result_ty = a.ty.scalarType(zcu);
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
         const result_ty = try v.resultType(self, a.ty);
 
@@ -2382,10 +2221,10 @@ const NavGen = struct {
         const zcu = self.pt.zcu;
 
         const v = self.vectorization(.{ condition, lhs, rhs });
-        const ops = v.operations();
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
 
-        const op_result_ty = try v.operationType(self, lhs.ty);
+        const op_result_ty = lhs.ty.scalarType(zcu);
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
         const result_ty = try v.resultType(self, lhs.ty);
 
@@ -2431,10 +2270,10 @@ const NavGen = struct {
 
     fn buildCmp(self: *NavGen, pred: CmpPredicate, lhs: Temporary, rhs: Temporary) !Temporary {
         const v = self.vectorization(.{ lhs, rhs });
-        const ops = v.operations();
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
 
-        const op_result_ty = try v.operationType(self, Type.bool);
+        const op_result_ty: Type = .bool;
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
         const result_ty = try v.resultType(self, Type.bool);
 
@@ -2498,22 +2337,12 @@ const NavGen = struct {
     };
 
     fn buildUnary(self: *NavGen, op: UnaryOp, operand: Temporary) !Temporary {
+        const zcu = self.pt.zcu;
         const target = self.spv.target;
-        const v = blk: {
-            const v = self.vectorization(.{operand});
-            break :blk switch (op) {
-                // TODO: These instructions don't seem to be working
-                // properly for LLVM-based backends on OpenCL for 8- and
-                // 16-component vectors.
-                .i_abs => if (self.spv.hasFeature(.vector16) and v.components() >= 8) v.unroll() else v,
-                else => v,
-            };
-        };
-
-        const ops = v.operations();
+        const v = self.vectorization(.{operand});
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
-
-        const op_result_ty = try v.operationType(self, operand.ty);
+        const op_result_ty = operand.ty.scalarType(zcu);
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
         const result_ty = try v.resultType(self, operand.ty);
 
@@ -2628,13 +2457,14 @@ const NavGen = struct {
     };
 
     fn buildBinary(self: *NavGen, op: BinaryOp, lhs: Temporary, rhs: Temporary) !Temporary {
+        const zcu = self.pt.zcu;
         const target = self.spv.target;
 
         const v = self.vectorization(.{ lhs, rhs });
-        const ops = v.operations();
+        const ops = v.components();
         const results = self.spv.allocIds(ops);
 
-        const op_result_ty = try v.operationType(self, lhs.ty);
+        const op_result_ty = lhs.ty.scalarType(zcu);
         const op_result_ty_id = try self.resolveType(op_result_ty, .direct);
         const result_ty = try v.resultType(self, lhs.ty);
 
@@ -2730,9 +2560,9 @@ const NavGen = struct {
         const ip = &zcu.intern_pool;
 
         const v = lhs.vectorization(self).unify(rhs.vectorization(self));
-        const ops = v.operations();
+        const ops = v.components();
 
-        const arith_op_ty = try v.operationType(self, lhs.ty);
+        const arith_op_ty = lhs.ty.scalarType(zcu);
         const arith_op_ty_id = try self.resolveType(arith_op_ty, .direct);
 
         const lhs_op = try v.prepare(self, lhs);
@@ -3175,17 +3005,18 @@ const NavGen = struct {
     /// Convert representation from indirect (in memory) to direct (in 'register')
     /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct).
     fn convertToDirect(self: *NavGen, ty: Type, operand_id: IdRef) !IdRef {
-        const zcu = self.pt.zcu;
+        const pt = self.pt;
+        const zcu = pt.zcu;
         switch (ty.scalarType(zcu).zigTypeTag(zcu)) {
             .bool => {
                 const false_id = try self.constBool(false, .indirect);
-                // The operation below requires inputs in direct representation, but the operand
-                // is actually in indirect representation.
-                // Cheekily swap out the type to the direct equivalent of the indirect type here, they have the
-                // same representation when converted to SPIR-V.
-                const operand_ty = try self.zigScalarOrVectorTypeLike(Type.u1, ty);
-                // Note: We can guarantee that these are the same ID due to the SPIR-V Module's `vector_types` cache!
-                assert(try self.resolveType(operand_ty, .direct) == try self.resolveType(ty, .indirect));
+                const operand_ty = blk: {
+                    if (!ty.isVector(pt.zcu)) break :blk Type.u1;
+                    break :blk try pt.vectorType(.{
+                        .len = ty.vectorLen(pt.zcu),
+                        .child = Type.u1.toIntern(),
+                    });
+                };
 
                 const result = try self.buildCmp(
                     .i_ne,
@@ -3226,7 +3057,6 @@ const NavGen = struct {
     }
 
     fn extractVectorComponent(self: *NavGen, result_ty: Type, vector_id: IdRef, field: u32) !IdRef {
-        // Whether this is an OpTypeVector or OpTypeArray, we need to emit the same instruction regardless.
         const result_ty_id = try self.resolveType(result_ty, .direct);
         const result_id = self.spv.allocId();
         const indexes = [_]u32{field};
@@ -3485,7 +3315,7 @@ const NavGen = struct {
         // Note: The sign may differ here between the shift and the base type, in case
         // of an arithmetic right shift. SPIR-V still expects the same type,
         // so in that case we have to cast convert to signed.
-        const casted_shift = try self.buildIntConvert(base.ty.scalarType(zcu), shift);
+        const casted_shift = try self.buildConvert(base.ty.scalarType(zcu), shift);
 
         const shifted = switch (info.signedness) {
             .unsigned => try self.buildBinary(unsigned, base, casted_shift),
@@ -3815,12 +3645,12 @@ const NavGen = struct {
             .unsigned => blk: {
                 if (maybe_op_ty_bits) |op_ty_bits| {
                     const op_ty = try pt.intType(.unsigned, op_ty_bits);
-                    const casted_lhs = try self.buildIntConvert(op_ty, lhs);
-                    const casted_rhs = try self.buildIntConvert(op_ty, rhs);
+                    const casted_lhs = try self.buildConvert(op_ty, lhs);
+                    const casted_rhs = try self.buildConvert(op_ty, rhs);
 
                     const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs);
 
-                    const low_bits = try self.buildIntConvert(lhs.ty, full_result);
+                    const low_bits = try self.buildConvert(lhs.ty, full_result);
                     const result = try self.normalize(low_bits, info);
 
                     // Shift the result bits away to get the overflow bits.
@@ -3846,9 +3676,7 @@ const NavGen = struct {
                 const high_overflowed = try self.buildCmp(.i_ne, zero, high_bits);
 
                 // If no overflow bits in low_bits, no extra work needs to be done.
-                if (info.backing_bits == info.bits) {
-                    break :blk .{ result, high_overflowed };
-                }
+                if (info.backing_bits == info.bits) break :blk .{ result, high_overflowed };
 
                 // Shift the result bits away to get the overflow bits.
                 const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits));
@@ -3886,13 +3714,13 @@ const NavGen = struct {
                 if (maybe_op_ty_bits) |op_ty_bits| {
                     const op_ty = try pt.intType(.signed, op_ty_bits);
                     // Assume normalized; sign bit is set. We want a sign extend.
-                    const casted_lhs = try self.buildIntConvert(op_ty, lhs);
-                    const casted_rhs = try self.buildIntConvert(op_ty, rhs);
+                    const casted_lhs = try self.buildConvert(op_ty, lhs);
+                    const casted_rhs = try self.buildConvert(op_ty, rhs);
 
                     const full_result = try self.buildBinary(.i_mul, casted_lhs, casted_rhs);
 
                     // Truncate to the result type.
-                    const low_bits = try self.buildIntConvert(lhs.ty, full_result);
+                    const low_bits = try self.buildConvert(lhs.ty, full_result);
                     const result = try self.normalize(low_bits, info);
 
                     // Now, we need to check the overflow bits AND the sign
@@ -3929,9 +3757,7 @@ const NavGen = struct {
                 // If no overflow bits in low_bits, no extra work needs to be done.
                 // Careful, we still have to check the sign bit, so this branch
                 // only goes for i33 and such.
-                if (info.backing_bits == info.bits + 1) {
-                    break :blk .{ result, high_overflowed };
-                }
+                if (info.backing_bits == info.bits + 1) break :blk .{ result, high_overflowed };
 
                 // Shift the result bits away to get the overflow bits.
                 const shift = Temporary.init(lhs.ty, try self.constInt(lhs.ty, info.bits - 1));
@@ -3972,7 +3798,7 @@ const NavGen = struct {
 
         // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
         // so just manually upcast it if required.
-        const casted_shift = try self.buildIntConvert(base.ty.scalarType(zcu), shift);
+        const casted_shift = try self.buildConvert(base.ty.scalarType(zcu), shift);
 
         const left = try self.buildBinary(.sll, base, casted_shift);
         const result = try self.normalize(left, info);
@@ -4026,7 +3852,7 @@ const NavGen = struct {
         // Result of OpenCL ctz/clz returns operand.ty, and we want result_ty.
         // result_ty is always large enough to hold the result, so we might have to down
         // cast it.
-        const result = try self.buildIntConvert(scalar_result_ty, count);
+        const result = try self.buildConvert(scalar_result_ty, count);
         return try result.materialize(self);
     }
 
@@ -4057,11 +3883,8 @@ const NavGen = struct {
         const operand_ty = self.typeOf(reduce.operand);
         const scalar_ty = operand_ty.scalarType(zcu);
         const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
-
         const info = self.arithmeticTypeInfo(operand_ty);
-
         const len = operand_ty.vectorLen(zcu);
-
         const first = try self.extractVectorComponent(scalar_ty, operand, 0);
 
         switch (reduce.operation) {
@@ -4136,51 +3959,9 @@ const NavGen = struct {
 
         // Note: number of components in the result, a, and b may differ.
         const result_ty = self.typeOfIndex(inst);
-        const a_ty = self.typeOf(extra.a);
-        const b_ty = self.typeOf(extra.b);
-
         const scalar_ty = result_ty.scalarType(zcu);
         const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
 
-        // If all of the types are SPIR-V vectors, we can use OpVectorShuffle.
-        if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) {
-            // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are
-            // numbered consecutively instead of using negatives.
-
-            const components = try self.gpa.alloc(Word, result_ty.vectorLen(zcu));
-            defer self.gpa.free(components);
-
-            const a_len = a_ty.vectorLen(zcu);
-
-            for (components, 0..) |*component, i| {
-                const elem = try mask.elemValue(pt, i);
-                if (elem.isUndef(zcu)) {
-                    // This is explicitly valid for OpVectorShuffle, it indicates undefined.
-                    component.* = 0xFFFF_FFFF;
-                    continue;
-                }
-
-                const index = elem.toSignedInt(zcu);
-                if (index >= 0) {
-                    component.* = @intCast(index);
-                } else {
-                    component.* = @intCast(~index + a_len);
-                }
-            }
-
-            const result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{
-                .id_result_type = try self.resolveType(result_ty, .direct),
-                .id_result = result_id,
-                .vector_1 = a,
-                .vector_2 = b,
-                .components = components,
-            });
-            return result_id;
-        }
-
-        // Fall back to manually extracting and inserting components.
-
         const constituents = try self.gpa.alloc(IdRef, result_ty.vectorLen(zcu));
         defer self.gpa.free(constituents);
 
@@ -4535,9 +4316,7 @@ const NavGen = struct {
         const dst_ty_id = try self.resolveType(dst_ty, .direct);
 
         const result_id = blk: {
-            if (src_ty_id == dst_ty_id) {
-                break :blk src_id;
-            }
+            if (src_ty_id == dst_ty_id) break :blk src_id;
 
             // TODO: Some more cases are missing here
             //   See fn bitCast in llvm.zig
@@ -4618,7 +4397,7 @@ const NavGen = struct {
             return try src.materialize(self);
         }
 
-        const converted = try self.buildIntConvert(dst_ty, src);
+        const converted = try self.buildConvert(dst_ty, src);
 
         // Make sure to normalize the result if shrinking.
         // Because strange ints are sign extended in their backing
@@ -4698,17 +4477,10 @@ const NavGen = struct {
 
     fn airFloatCast(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
-        const operand_id = try self.resolve(ty_op.operand);
+        const operand = try self.temporary(ty_op.operand);
         const dest_ty = self.typeOfIndex(inst);
-        const dest_ty_id = try self.resolveType(dest_ty, .direct);
-
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpFConvert, .{
-            .id_result_type = dest_ty_id,
-            .id_result = result_id,
-            .float_value = operand_id,
-        });
-        return result_id;
+        const result = try self.buildConvert(dest_ty, operand);
+        return try result.materialize(self);
     }
 
     fn airNot(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
@@ -4796,7 +4568,7 @@ const NavGen = struct {
                             break :blk try self.bitCast(field_int_ty, field_ty, field_id);
                         };
                         const shift_rhs = try self.constInt(backing_int_ty, running_bits);
-                        const extended_int_conv = try self.buildIntConvert(backing_int_ty, .{
+                        const extended_int_conv = try self.buildConvert(backing_int_ty, .{
                             .ty = field_int_ty,
                             .value = .{ .singleton = field_int_id },
                         });
@@ -5016,17 +4788,6 @@ const NavGen = struct {
         const array_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
 
-        if (self.isSpvVector(array_ty)) {
-            const result_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpVectorExtractDynamic, .{
-                .id_result_type = try self.resolveType(elem_ty, .direct),
-                .id_result = result_id,
-                .vector = array_id,
-                .index = index_id,
-            });
-            return result_id;
-        }
-
         // SPIR-V doesn't have an array indexing function for some damn reason.
         // For now, just generate a temporary and use that.
         // TODO: This backend probably also should use isByRef from llvm...
@@ -5173,7 +4934,7 @@ const NavGen = struct {
                     return self.bitCast(ty, payload_ty, payload.?);
                 }
 
-                const trunc = try self.buildIntConvert(ty, .{ .ty = payload_ty, .value = .{ .singleton = payload.? } });
+                const trunc = try self.buildConvert(ty, .{ .ty = payload_ty, .value = .{ .singleton = payload.? } });
                 return try trunc.materialize(self);
             }
 
@@ -5182,7 +4943,7 @@ const NavGen = struct {
                 try self.convertToIndirect(payload_ty, payload.?)
             else
                 try self.bitCast(payload_int_ty, payload_ty, payload.?);
-            const trunc = try self.buildIntConvert(ty, .{ .ty = payload_int_ty, .value = .{ .singleton = payload_int } });
+            const trunc = try self.buildConvert(ty, .{ .ty = payload_int_ty, .value = .{ .singleton = payload_int } });
             return try trunc.materialize(self);
         }
 
@@ -5273,7 +5034,7 @@ const NavGen = struct {
                     const result_id = blk: {
                         if (self.backingIntBits(field_bit_size).? == self.backingIntBits(@intCast(object_ty.bitSize(zcu))).?)
                             break :blk try self.bitCast(field_int_ty, object_ty, try masked.materialize(self));
-                        const trunc = try self.buildIntConvert(field_int_ty, masked);
+                        const trunc = try self.buildConvert(field_int_ty, masked);
                         break :blk try trunc.materialize(self);
                     };
                     if (field_ty.ip_index == .bool_type) return try self.convertToDirect(.bool, result_id);
@@ -5297,7 +5058,7 @@ const NavGen = struct {
                     const result_id = blk: {
                         if (self.backingIntBits(field_bit_size).? == self.backingIntBits(@intCast(backing_int_ty.bitSize(zcu))).?)
                             break :blk try self.bitCast(int_ty, backing_int_ty, try masked.materialize(self));
-                        const trunc = try self.buildIntConvert(int_ty, masked);
+                        const trunc = try self.buildConvert(int_ty, masked);
                         break :blk try trunc.materialize(self);
                     };
                     if (field_ty.ip_index == .bool_type) return try self.convertToDirect(.bool, result_id);
@@ -6752,7 +6513,7 @@ const NavGen = struct {
         // TODO: Should we make these builtins return usize?
         const result_id = try self.builtin3D(Type.u64, .LocalInvocationId, dimension, 0);
         const tmp = Temporary.init(Type.u64, result_id);
-        const result = try self.buildIntConvert(Type.u32, tmp);
+        const result = try self.buildConvert(Type.u32, tmp);
         return try result.materialize(self);
     }
 
@@ -6763,7 +6524,7 @@ const NavGen = struct {
         // TODO: Should we make these builtins return usize?
         const result_id = try self.builtin3D(Type.u64, .WorkgroupSize, dimension, 0);
         const tmp = Temporary.init(Type.u64, result_id);
-        const result = try self.buildIntConvert(Type.u32, tmp);
+        const result = try self.buildConvert(Type.u32, tmp);
         return try result.materialize(self);
     }
 
@@ -6774,7 +6535,7 @@ const NavGen = struct {
         // TODO: Should we make these builtins return usize?
         const result_id = try self.builtin3D(Type.u64, .WorkgroupId, dimension, 0);
         const tmp = Temporary.init(Type.u64, result_id);
-        const result = try self.buildIntConvert(Type.u32, tmp);
+        const result = try self.buildConvert(Type.u32, tmp);
         return try result.materialize(self);
     }