Commit 5ee96688a7

Ali Chraghi <alichraghi@proton.me>
2024-02-09 06:43:11
spirv: emit vectorized operations
1 parent ddcea2c
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -629,6 +629,16 @@ const DeclGen = struct {
         return self.backingIntBits(ty) == null;
     }
 
+    /// Checks whether the type can be directly translated to SPIR-V vectors
+    fn isVector(self: *DeclGen, ty: Type) bool {
+        const mod = self.module;
+        if (ty.zigTypeTag(mod) != .Vector) return false;
+        const elem_ty = ty.childType(mod);
+        const len = ty.vectorLen(mod);
+        const is_scalar = elem_ty.isNumeric(mod) or elem_ty.toIntern() == .bool_type;
+        return is_scalar and len > 1 and len <= 4;
+    }
+
     fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
         const mod = self.module;
         const target = self.getTarget();
@@ -694,6 +704,24 @@ const DeclGen = struct {
     /// This function, unlike SpvModule.constInt, takes care to bitcast
     /// the value to an unsigned int first for Kernels.
     fn constInt(self: *DeclGen, ty_ref: CacheRef, value: anytype) !IdRef {
+        switch (self.spv.cache.lookup(ty_ref)) {
+            .vector_type => |vec_type| {
+                const elem_ids = try self.gpa.alloc(IdRef, vec_type.component_count);
+                defer self.gpa.free(elem_ids);
+                const int_value = try self.constInt(vec_type.component_type, value);
+                @memset(elem_ids, int_value);
+
+                const constituents_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = constituents_id,
+                    .constituents = elem_ids,
+                });
+                return constituents_id;
+            },
+            else => {},
+        }
+
         if (value < 0) {
             const ty = self.spv.cache.lookup(ty_ref).int_type;
             // Manually truncate the value so that the resulting value
@@ -711,6 +739,24 @@ const DeclGen = struct {
 
     /// Emits a float constant
     fn constFloat(self: *DeclGen, ty_ref: CacheRef, value: f128) !IdRef {
+        switch (self.spv.cache.lookup(ty_ref)) {
+            .vector_type => |vec_type| {
+                const elem_ids = try self.gpa.alloc(IdRef, vec_type.component_count);
+                defer self.gpa.free(elem_ids);
+                const int_value = try self.constFloat(vec_type.component_type, value);
+                @memset(elem_ids, int_value);
+
+                const constituents_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = constituents_id,
+                    .constituents = elem_ids,
+                });
+                return constituents_id;
+            },
+            else => {},
+        }
+
         const ty = self.spv.cache.lookup(ty_ref).float_type;
         return switch (ty.bits) {
             16 => try self.spv.resolveId(.{ .float = .{ .ty = ty_ref, .value = .{ .float16 = @floatCast(value) } } }),
@@ -726,9 +772,9 @@ const DeclGen = struct {
     /// if the parameters are in indirect representation, then the result is too.
     fn constructComposite(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
         const constituents_id = self.spv.allocId();
-        const type_id = try self.resolveTypeId(ty);
+        const type_id = try self.resolveType(ty, .direct);
         try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
-            .id_result_type = type_id,
+            .id_result_type = self.typeId(type_id),
             .id_result = constituents_id,
             .constituents = constituents,
         });
@@ -901,19 +947,19 @@ const DeclGen = struct {
                         .bytes => |bytes| {
                             // TODO: This is really space inefficient, perhaps there is a better
                             // way to do it?
-                            for (bytes, 0..) |byte, i| {
-                                constituents[i] = try self.constInt(elem_ty_ref, byte);
+                            for (constituents, bytes) |*constituent, byte| {
+                                constituent.* = try self.constInt(elem_ty_ref, byte);
                             }
                         },
                         .elems => |elems| {
-                            for (0..@as(usize, @intCast(array_type.len))) |i| {
-                                constituents[i] = try self.constant(elem_ty, Value.fromInterned(elems[i]), .indirect);
+                            for (constituents, elems) |*constituent, elem| {
+                                constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), .indirect);
                             }
                         },
                         .repeated_elem => |elem| {
                             const val_id = try self.constant(elem_ty, Value.fromInterned(elem), .indirect);
-                            for (0..@as(usize, @intCast(array_type.len))) |i| {
-                                constituents[i] = val_id;
+                            for (constituents) |*constituent| {
+                                constituent.* = val_id;
                             }
                         },
                     }
@@ -1448,12 +1494,11 @@ const DeclGen = struct {
                 const elem_ty = ty.childType(mod);
                 const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
                 const len = ty.vectorLen(mod);
-                const is_scalar = elem_ty.isNumeric(mod) or elem_ty.toIntern() == .bool_type;
 
-                const ty_ref = if (is_scalar and len > 1 and len <= 4)
-                    try self.spv.vectorType(ty.vectorLen(mod), elem_ty_ref)
+                const ty_ref = if (self.isVector(ty))
+                    try self.spv.vectorType(len, elem_ty_ref)
                 else
-                    try self.spv.arrayType(ty.vectorLen(mod), elem_ty_ref);
+                    try self.spv.arrayType(len, elem_ty_ref);
 
                 try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
                 return ty_ref;
@@ -1752,18 +1797,16 @@ const DeclGen = struct {
     }
 
     /// This structure is used as helper for element-wise operations. It is intended
-    /// to be used with both vectors and single elements.
+    /// to be used with vectors, fake vectors (arrays) and single elements.
     const WipElementWise = struct {
         dg: *DeclGen,
         result_ty: Type,
+        ty: Type,
         /// Always in direct representation.
-        result_ty_ref: CacheRef,
-        scalar_ty: Type,
-        /// Always in direct representation.
-        scalar_ty_ref: CacheRef,
-        scalar_ty_id: IdRef,
-        /// True if the input is actually a vector type.
-        is_vector: bool,
+        ty_ref: CacheRef,
+        ty_id: IdRef,
+        /// True if the input is an array type.
+        is_array: bool,
         /// The element-wise operation should fill these results before calling finalize().
         /// These should all be in **direct** representation! `finalize()` will convert
         /// them to indirect if required.
@@ -1774,29 +1817,28 @@ const DeclGen = struct {
         }
 
         /// Utility function to extract the element at a particular index in an
-        /// input vector. This type is expected to be a vector if `wip.is_vector`, and
-        /// a scalar otherwise.
+        /// input array. This type is expected to be a fake vector (array) if `wip.is_array`, and
+        /// a vector or scalar otherwise.
         fn elementAt(wip: WipElementWise, ty: Type, value: IdRef, index: usize) !IdRef {
             const mod = wip.dg.module;
-            if (wip.is_vector) {
+            if (wip.is_array) {
                 assert(ty.isVector(mod));
                 return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
             } else {
-                assert(!ty.isVector(mod));
                 assert(index == 0);
                 return value;
             }
         }
 
-        /// Turns the results of this WipElementWise into a result. This can either
-        /// be a vector or single element, depending on `result_ty`.
+        /// Turns the results of this WipElementWise into a result. This can be
+        /// vectors, fake vectors (arrays) and single elements, depending on `result_ty`.
         /// After calling this function, this WIP is no longer usable.
         /// Results is in `direct` representation.
         fn finalize(wip: *WipElementWise) !IdRef {
-            if (wip.is_vector) {
+            if (wip.is_array) {
                 // Convert all the constituents to indirect, as required for the array.
                 for (wip.results) |*result| {
-                    result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*);
+                    result.* = try wip.dg.convertToIndirect(wip.ty, result.*);
                 }
                 return try wip.dg.constructComposite(wip.result_ty, wip.results);
             } else {
@@ -1806,33 +1848,30 @@ const DeclGen = struct {
 
         /// Allocate a result id at a particular index, and return it.
         fn allocId(wip: *WipElementWise, index: usize) IdRef {
-            assert(wip.is_vector or index == 0);
+            assert(wip.is_array or index == 0);
             wip.results[index] = wip.dg.spv.allocId();
             return wip.results[index];
         }
     };
 
     /// Create a new element-wise operation.
-    fn elementWise(self: *DeclGen, result_ty: Type) !WipElementWise {
+    fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise {
         const mod = self.module;
-        // For now, this operation also reasons in terms of `.direct` representation.
-        const result_ty_ref = try self.resolveType(result_ty, .direct);
-        const is_vector = result_ty.isVector(mod);
-        const num_results = if (is_vector) result_ty.vectorLen(mod) else 1;
+        const is_array = result_ty.isVector(mod) and (!self.isVector(result_ty) or force_element_wise);
+        const num_results = if (is_array) result_ty.vectorLen(mod) else 1;
         const results = try self.gpa.alloc(IdRef, num_results);
-        for (results) |*result| result.* = undefined;
+        @memset(results, undefined);
 
-        const scalar_ty = result_ty.scalarType(mod);
-        const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+        const ty = if (is_array) result_ty.scalarType(mod) else result_ty;
+        const ty_ref = try self.resolveType(ty, .direct);
 
         return .{
             .dg = self,
             .result_ty = result_ty,
-            .result_ty_ref = result_ty_ref,
-            .scalar_ty = scalar_ty,
-            .scalar_ty_ref = scalar_ty_ref,
-            .scalar_ty_id = self.typeId(scalar_ty_ref),
-            .is_vector = is_vector,
+            .ty = ty,
+            .ty_ref = ty_ref,
+            .ty_id = self.typeId(ty_ref),
+            .is_array = is_array,
             .results = results,
         };
     }
@@ -2312,11 +2351,11 @@ const DeclGen = struct {
     }
 
     fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef {
-        var wip = try self.elementWise(ty);
+        var wip = try self.elementWise(ty, false);
         defer wip.deinit();
         for (0..wip.results.len) |i| {
             try self.func.body.emit(self.spv.gpa, opcode, .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = wip.allocId(i),
                 .operand_1 = try wip.elementAt(ty, lhs_id, i),
                 .operand_2 = try wip.elementAt(ty, rhs_id, i),
@@ -2345,7 +2384,7 @@ const DeclGen = struct {
 
         const result_ty = self.typeOfIndex(inst);
         const shift_ty = self.typeOf(bin_op.rhs);
-        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
+        const shift_ty_ref = try self.resolveType(shift_ty, .direct);
 
         const info = self.arithmeticTypeInfo(result_ty);
         switch (info.class) {
@@ -2354,7 +2393,7 @@ const DeclGen = struct {
             .float, .bool => unreachable,
         }
 
-        var wip = try self.elementWise(result_ty);
+        var wip = try self.elementWise(result_ty, false);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
@@ -2362,10 +2401,10 @@ const DeclGen = struct {
 
             // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
             // so just manually upcast it if required.
-            const shift_id = if (scalar_shift_ty_ref != wip.scalar_ty_ref) blk: {
+            const shift_id = if (shift_ty_ref != wip.ty_ref) blk: {
                 const shift_id = self.spv.allocId();
                 try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-                    .id_result_type = wip.scalar_ty_id,
+                    .id_result_type = wip.ty_id,
                     .id_result = shift_id,
                     .unsigned_value = rhs_elem_id,
                 });
@@ -2374,7 +2413,7 @@ const DeclGen = struct {
 
             const value_id = self.spv.allocId();
             const args = .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = value_id,
                 .base = lhs_elem_id,
                 .shift = shift_id,
@@ -2386,7 +2425,7 @@ const DeclGen = struct {
                 try self.func.body.emit(self.spv.gpa, unsigned, args);
             }
 
-            result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
+            result_id.* = try self.normalize(wip.ty_ref, value_id, info);
         }
         return try wip.finalize();
     }
@@ -2405,14 +2444,14 @@ const DeclGen = struct {
     fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
         const info = self.arithmeticTypeInfo(result_ty);
 
-        var wip = try self.elementWise(result_ty);
+        var wip = try self.elementWise(result_ty, true);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
             const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
 
             // TODO: Use fmin for OpenCL
-            const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id);
+            const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id);
             const selection_id = switch (info.class) {
                 .float => blk: {
                     // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
@@ -2440,7 +2479,7 @@ const DeclGen = struct {
 
             result_id.* = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = result_id.*,
                 .condition = selection_id,
                 .object_1 = lhs_elem_id,
@@ -2545,7 +2584,7 @@ const DeclGen = struct {
             .bool => unreachable,
         };
 
-        var wip = try self.elementWise(ty);
+        var wip = try self.elementWise(ty, false);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
@@ -2553,7 +2592,7 @@ const DeclGen = struct {
 
             const value_id = self.spv.allocId();
             const operands = .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = value_id,
                 .operand_1 = lhs_elem_id,
                 .operand_2 = rhs_elem_id,
@@ -2568,7 +2607,7 @@ const DeclGen = struct {
 
             // TODO: Trap on overflow? Probably going to be annoying.
             // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
-            result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, info);
+            result_id.* = try self.normalize(wip.ty_ref, value_id, info);
         }
 
         return try wip.finalize();
@@ -2582,12 +2621,12 @@ const DeclGen = struct {
         const operand_id = try self.resolve(ty_op.operand);
         // Note: operand_ty may be signed, while ty is always unsigned!
         const operand_ty = self.typeOf(ty_op.operand);
-        const ty = self.typeOfIndex(inst);
-        const info = self.arithmeticTypeInfo(ty);
+        const result_ty = self.typeOfIndex(inst);
+        const info = self.arithmeticTypeInfo(result_ty);
         const operand_scalar_ty = operand_ty.scalarType(mod);
         const operand_scalar_ty_ref = try self.resolveType(operand_scalar_ty, .direct);
 
-        var wip = try self.elementWise(ty);
+        var wip = try self.elementWise(result_ty, true);
         defer wip.deinit();
 
         const zero_id = switch (info.class) {
@@ -2615,7 +2654,7 @@ const DeclGen = struct {
                 .composite_integer => unreachable, // TODO
                 .bool => unreachable,
             }
-            const neg_norm_id = try self.normalize(wip.scalar_ty_ref, neg_id, info);
+            const neg_norm_id = try self.normalize(wip.ty_ref, neg_id, info);
 
             const gt_zero_id = try self.cmp(.gt, Type.bool, operand_scalar_ty, elem_id, zero_id);
             const abs_id = self.spv.allocId();
@@ -2627,7 +2666,7 @@ const DeclGen = struct {
                 .object_2 = neg_norm_id,
             });
             // For Shader, we may need to cast from signed to unsigned here.
-            result_id.* = try self.bitCast(wip.scalar_ty, operand_scalar_ty, abs_id);
+            result_id.* = try self.bitCast(wip.ty, operand_scalar_ty, abs_id);
         }
         return try wip.finalize();
     }
@@ -2641,6 +2680,7 @@ const DeclGen = struct {
     ) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
 
+        const mod = self.module;
         const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
         const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
         const lhs = try self.resolve(extra.lhs);
@@ -2651,6 +2691,10 @@ const DeclGen = struct {
         const ov_ty = result_ty.structFieldType(1, self.module);
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+        const cmp_ty_ref = if (self.isVector(operand_ty))
+            try self.spv.vectorType(operand_ty.vectorLen(mod), bool_ty_ref)
+        else
+            bool_ty_ref;
 
         const info = self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
@@ -2659,9 +2703,9 @@ const DeclGen = struct {
             .float, .bool => unreachable,
         }
 
-        var wip_result = try self.elementWise(operand_ty);
+        var wip_result = try self.elementWise(operand_ty, false);
         defer wip_result.deinit();
-        var wip_ov = try self.elementWise(ov_ty);
+        var wip_ov = try self.elementWise(ov_ty, false);
         defer wip_ov.deinit();
         for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
             const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
@@ -2671,14 +2715,14 @@ const DeclGen = struct {
             const value_id = self.spv.allocId();
 
             try self.func.body.emit(self.spv.gpa, add, .{
-                .id_result_type = wip_result.scalar_ty_id,
+                .id_result_type = wip_result.ty_id,
                 .id_result = value_id,
                 .operand_1 = lhs_elem_id,
                 .operand_2 = rhs_elem_id,
             });
 
             // Normalize the result so that the comparisons go well
-            result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+            result_id.* = try self.normalize(wip_result.ty_ref, value_id, info);
 
             const overflowed_id = switch (info.signedness) {
                 .unsigned => blk: {
@@ -2686,7 +2730,7 @@ const DeclGen = struct {
                     // For subtraction the conditions need to be swapped.
                     const overflowed_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, ucmp, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result_type = self.typeId(cmp_ty_ref),
                         .id_result = overflowed_id,
                         .operand_1 = result_id.*,
                         .operand_2 = lhs_elem_id,
@@ -2712,9 +2756,9 @@ const DeclGen = struct {
                     // = (rhs < 0) == (lhs > value)
 
                     const rhs_lt_zero_id = self.spv.allocId();
-                    const zero_id = try self.constInt(wip_result.scalar_ty_ref, 0);
+                    const zero_id = try self.constInt(wip_result.ty_ref, 0);
                     try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result_type = self.typeId(cmp_ty_ref),
                         .id_result = rhs_lt_zero_id,
                         .operand_1 = rhs_elem_id,
                         .operand_2 = zero_id,
@@ -2722,7 +2766,7 @@ const DeclGen = struct {
 
                     const value_gt_lhs_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, scmp, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result_type = self.typeId(cmp_ty_ref),
                         .id_result = value_gt_lhs_id,
                         .operand_1 = lhs_elem_id,
                         .operand_2 = result_id.*,
@@ -2730,7 +2774,7 @@ const DeclGen = struct {
 
                     const overflowed_id = self.spv.allocId();
                     try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result_type = self.typeId(cmp_ty_ref),
                         .id_result = overflowed_id,
                         .operand_1 = rhs_lt_zero_id,
                         .operand_2 = value_gt_lhs_id,
@@ -2739,7 +2783,7 @@ const DeclGen = struct {
                 },
             };
 
-            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+            ov_id.* = try self.intFromBool(wip_ov.ty_ref, overflowed_id);
         }
 
         return try self.constructComposite(
@@ -2759,11 +2803,15 @@ const DeclGen = struct {
         const result_ty = self.typeOfIndex(inst);
         const operand_ty = self.typeOf(extra.lhs);
         const shift_ty = self.typeOf(extra.rhs);
-        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
+        const shift_ty_ref = try self.resolveType(shift_ty, .direct);
 
         const ov_ty = result_ty.structFieldType(1, self.module);
 
         const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+        const cmp_ty_ref = if (self.isVector(operand_ty))
+            try self.spv.vectorType(operand_ty.vectorLen(mod), bool_ty_ref)
+        else
+            bool_ty_ref;
 
         const info = self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
@@ -2772,9 +2820,9 @@ const DeclGen = struct {
             .float, .bool => unreachable,
         }
 
-        var wip_result = try self.elementWise(operand_ty);
+        var wip_result = try self.elementWise(operand_ty, false);
         defer wip_result.deinit();
-        var wip_ov = try self.elementWise(ov_ty);
+        var wip_ov = try self.elementWise(ov_ty, false);
         defer wip_ov.deinit();
         for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
             const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
@@ -2782,10 +2830,10 @@ const DeclGen = struct {
 
             // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
             // so just manually upcast it if required.
-            const shift_id = if (scalar_shift_ty_ref != wip_result.scalar_ty_ref) blk: {
+            const shift_id = if (shift_ty_ref != wip_result.ty_ref) blk: {
                 const shift_id = self.spv.allocId();
                 try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-                    .id_result_type = wip_result.scalar_ty_id,
+                    .id_result_type = wip_result.ty_id,
                     .id_result = shift_id,
                     .unsigned_value = rhs_elem_id,
                 });
@@ -2794,18 +2842,18 @@ const DeclGen = struct {
 
             const value_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
-                .id_result_type = wip_result.scalar_ty_id,
+                .id_result_type = wip_result.ty_id,
                 .id_result = value_id,
                 .base = lhs_elem_id,
                 .shift = shift_id,
             });
-            result_id.* = try self.normalize(wip_result.scalar_ty_ref, value_id, info);
+            result_id.* = try self.normalize(wip_result.ty_ref, value_id, info);
 
             const right_shift_id = self.spv.allocId();
             switch (info.signedness) {
                 .signed => {
                     try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
-                        .id_result_type = wip_result.scalar_ty_id,
+                        .id_result_type = wip_result.ty_id,
                         .id_result = right_shift_id,
                         .base = result_id.*,
                         .shift = shift_id,
@@ -2813,7 +2861,7 @@ const DeclGen = struct {
                 },
                 .unsigned => {
                     try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{
-                        .id_result_type = wip_result.scalar_ty_id,
+                        .id_result_type = wip_result.ty_id,
                         .id_result = right_shift_id,
                         .base = result_id.*,
                         .shift = shift_id,
@@ -2823,13 +2871,13 @@ const DeclGen = struct {
 
             const overflowed_id = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
-                .id_result_type = self.typeId(bool_ty_ref),
+                .id_result_type = self.typeId(cmp_ty_ref),
                 .id_result = overflowed_id,
                 .operand_1 = lhs_elem_id,
                 .operand_2 = right_shift_id,
             });
 
-            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+            ov_id.* = try self.intFromBool(wip_ov.ty_ref, overflowed_id);
         }
 
         return try self.constructComposite(
@@ -2853,19 +2901,19 @@ const DeclGen = struct {
         const info = self.arithmeticTypeInfo(ty);
         assert(info.class == .float); // .mul_add is only emitted for floats
 
-        var wip = try self.elementWise(ty);
+        var wip = try self.elementWise(ty, false);
         defer wip.deinit();
         for (0..wip.results.len) |i| {
             const mul_result = self.spv.allocId();
             try self.func.body.emit(self.spv.gpa, .OpFMul, .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = mul_result,
                 .operand_1 = try wip.elementAt(ty, mulend1, i),
                 .operand_2 = try wip.elementAt(ty, mulend2, i),
             });
 
             try self.func.body.emit(self.spv.gpa, .OpFAdd, .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = wip.allocId(i),
                 .operand_1 = mul_result,
                 .operand_2 = try wip.elementAt(ty, addend, i),
@@ -2879,11 +2927,9 @@ const DeclGen = struct {
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         const result_ty = self.typeOfIndex(inst);
-        var wip = try self.elementWise(result_ty);
+        var wip = try self.elementWise(result_ty, true);
         defer wip.deinit();
-        for (wip.results) |*result_id| {
-            result_id.* = operand_id;
-        }
+        @memset(wip.results, operand_id);
         return try wip.finalize();
     }
 
@@ -2965,20 +3011,20 @@ const DeclGen = struct {
 
         const ty = self.typeOfIndex(inst);
 
-        var wip = try self.elementWise(ty);
+        var wip = try self.elementWise(ty, true);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const elem = try mask.elemValue(mod, i);
             if (elem.isUndef(mod)) {
-                result_id.* = try self.spv.constUndef(wip.scalar_ty_ref);
+                result_id.* = try self.spv.constUndef(wip.ty_ref);
                 continue;
             }
 
             const index = elem.toSignedInt(mod);
             if (index >= 0) {
-                result_id.* = try self.extractField(wip.scalar_ty, a, @intCast(index));
+                result_id.* = try self.extractField(wip.ty, a, @intCast(index));
             } else {
-                result_id.* = try self.extractField(wip.scalar_ty, b, @intCast(~index));
+                result_id.* = try self.extractField(wip.ty, b, @intCast(~index));
             }
         }
         return try wip.finalize();
@@ -3188,7 +3234,7 @@ const DeclGen = struct {
                 return result_id;
             },
             .Vector => {
-                var wip = try self.elementWise(result_ty);
+                var wip = try self.elementWise(result_ty, true);
                 defer wip.deinit();
                 const scalar_ty = ty.scalarType(mod);
                 for (wip.results, 0..) |*result_id, i| {
@@ -3374,19 +3420,19 @@ const DeclGen = struct {
             return operand_id;
         }
 
-        var wip = try self.elementWise(dst_ty);
+        var wip = try self.elementWise(dst_ty, false);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const elem_id = try wip.elementAt(src_ty, operand_id, i);
             const value_id = self.spv.allocId();
             switch (dst_info.signedness) {
                 .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
-                    .id_result_type = wip.scalar_ty_id,
+                    .id_result_type = wip.ty_id,
                     .id_result = value_id,
                     .signed_value = elem_id,
                 }),
                 .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-                    .id_result_type = wip.scalar_ty_id,
+                    .id_result_type = wip.ty_id,
                     .id_result = value_id,
                     .unsigned_value = elem_id,
                 }),
@@ -3397,7 +3443,7 @@ const DeclGen = struct {
             // type, we don't need to normalize when growing the type. The
             // representation is already the same.
             if (dst_info.bits < src_info.bits) {
-                result_id.* = try self.normalize(wip.scalar_ty_ref, value_id, dst_info);
+                result_id.* = try self.normalize(wip.ty_ref, value_id, dst_info);
             } else {
                 result_id.* = value_id;
             }
@@ -3482,11 +3528,11 @@ const DeclGen = struct {
         const operand_id = try self.resolve(un_op);
         const result_ty = self.typeOfIndex(inst);
 
-        var wip = try self.elementWise(result_ty);
+        var wip = try self.elementWise(result_ty, false);
         defer wip.deinit();
         for (wip.results, 0..) |*result_id, i| {
             const elem_id = try wip.elementAt(Type.bool, operand_id, i);
-            result_id.* = try self.intFromBool(wip.scalar_ty_ref, elem_id);
+            result_id.* = try self.intFromBool(wip.ty_ref, elem_id);
         }
         return try wip.finalize();
     }
@@ -3515,12 +3561,12 @@ const DeclGen = struct {
         const result_ty = self.typeOfIndex(inst);
         const info = self.arithmeticTypeInfo(result_ty);
 
-        var wip = try self.elementWise(result_ty);
+        var wip = try self.elementWise(result_ty, false);
         defer wip.deinit();
 
         for (0..wip.results.len) |i| {
             const args = .{
-                .id_result_type = wip.scalar_ty_id,
+                .id_result_type = wip.ty_id,
                 .id_result = wip.allocId(i),
                 .operand = try wip.elementAt(result_ty, operand_id, i),
             };
@@ -3563,10 +3609,7 @@ const DeclGen = struct {
             // Convert the pointer-to-array to a pointer to the first element.
             try self.accessChain(elem_ptr_ty_ref, array_ptr_id, &.{0});
 
-        return try self.constructComposite(
-            slice_ty,
-            &.{ elem_ptr_id, len_id },
-        );
+        return try self.constructComposite(slice_ty, &.{ elem_ptr_id, len_id });
     }
 
     fn airSlice(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3580,10 +3623,7 @@ const DeclGen = struct {
 
         // Note: Types should not need to be converted to direct, these types
         // dont need to be converted.
-        return try self.constructComposite(
-            slice_ty,
-            &.{ ptr_id, len_id },
-        );
+        return try self.constructComposite(slice_ty, &.{ ptr_id, len_id });
     }
 
     fn airAggregateInit(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -3641,9 +3681,9 @@ const DeclGen = struct {
                 const elem_ids = try self.gpa.alloc(IdRef, n_elems);
                 defer self.gpa.free(elem_ids);
 
-                for (elements, 0..) |element, i| {
+                for (elements, elem_ids) |element, *elem_id| {
                     const id = try self.resolve(element);
-                    elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
+                    elem_id.* = try self.convertToIndirect(result_ty.childType(mod), id);
                 }
 
                 return try self.constructComposite(result_ty, elem_ids);
@@ -3654,9 +3694,9 @@ const DeclGen = struct {
                 const elem_ids = try self.gpa.alloc(IdRef, n_elems);
                 defer self.gpa.free(elem_ids);
 
-                for (elements, 0..) |element, i| {
+                for (elements, elem_ids) |element, *elem_id| {
                     const id = try self.resolve(element);
-                    elem_ids[i] = try self.convertToIndirect(array_info.elem_type, id);
+                    elem_id.* = try self.convertToIndirect(array_info.elem_type, id);
                 }
 
                 if (array_info.sentinel) |sentinel_val| {