Commit 4bd9d9b7e0

Robin Voetter <robin@voetter.nl>
2024-06-02 15:57:18
spirv: change direct vector child repr to direct
Previously the child type of a vector was always in indirect representation. Concretely, this meant that vectors of bools are represented by vectors of u8. This was undesirable because it introduced a difference between vectorizable operations with a scalar bool and a vector of bool. This commit changes the representation to be the same for vectors and scalars everywhere. Some issues arised with constructing vectors: it seems the previous temporary- and-pointer approach does not work properly with vectors of bool. To work around this, simply use OpCompositeConstruct. This is the proper instruction for this, but it was previously not used because of a now-solved limitation in the SPIRV-LLVM-Translator. It was not yet applied to Zig because the Intel OpenCL CPU runtime does not have a recent enough version of the translator yet, but to solve that we just switch to testing with POCL instead.
1 parent b9d738a
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -22,6 +22,7 @@ const IdResultType = spec.IdResultType;
 const StorageClass = spec.StorageClass;
 
 const SpvModule = @import("spirv/Module.zig");
+const IdRange = SpvModule.IdRange;
 
 const SpvSection = @import("spirv/Section.zig");
 const SpvAssembler = @import("spirv/Assembler.zig");
@@ -32,7 +33,7 @@ pub const zig_call_abi_ver = 3;
 
 const InternMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, DeclGen.Repr }, IdResult);
 const PtrTypeMap = std.AutoHashMapUnmanaged(
-    struct { InternPool.Index, StorageClass },
+    struct { InternPool.Index, StorageClass, DeclGen.Repr },
     struct { ty_id: IdRef, fwd_emitted: bool },
 );
 
@@ -626,7 +627,7 @@ const DeclGen = struct {
     }
 
     /// Checks whether the type can be directly translated to SPIR-V vectors
-    fn isVector(self: *DeclGen, ty: Type) bool {
+    fn isSpvVector(self: *DeclGen, ty: Type) bool {
         const mod = self.module;
         const target = self.getTarget();
         if (ty.zigTypeTag(mod) != .Vector) return false;
@@ -798,26 +799,39 @@ const DeclGen = struct {
 
     /// Construct a vector at runtime.
     /// ty must be an vector type.
-    /// Constituents should be in `indirect` representation (as the elements of an vector should be).
-    /// Result is in `direct` representation.
     fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
-        // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
-        // operands are not constant.
+        const mod = self.module;
+        assert(ty.vectorLen(mod) == constituents.len);
+
+        // Note: older versions of the Khronos SPRIV-LLVM translator crash on this instruction
+        // because it cannot construct structs which' operands are not constant.
         // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
-        // For now, just initialize the struct by setting the fields manually...
-        // TODO: Make this OpCompositeConstruct when we can
+        // Currently this is the case for Intel OpenCL CPU runtime (2023-WW46), but the
+        // alternatives dont work properly:
+        // - using temporaries/pointers doesn't work properly with vectors of bool, causes
+        //   backends that use llvm to crash
+        // - using OpVectorInsertDynamic doesn't work for non-spirv-vectors of bool.
+
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
+            .id_result_type = try self.resolveType(ty, .direct),
+            .id_result = result_id,
+            .constituents = constituents,
+        });
+        return result_id;
+    }
+
+    /// Construct a vector at runtime with all lanes set to the same value.
+    /// ty must be an vector type.
+    fn constructVectorSplat(self: *DeclGen, ty: Type, constituent: IdRef) !IdRef {
         const mod = self.module;
-        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
-        const ptr_elem_ty_id = try self.ptrType(ty.elemType2(mod), .Function);
-        for (constituents, 0..) |constitent_id, index| {
-            const ptr_id = try self.accessChain(ptr_elem_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))});
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = constitent_id,
-            });
-        }
+        const n = ty.vectorLen(mod);
 
-        return try self.load(ty, ptr_composite_id, .{});
+        const constituents = try self.gpa.alloc(IdRef, n);
+        defer self.gpa.free(constituents);
+        @memset(constituents, constituent);
+
+        return try self.constructVector(ty, constituents);
     }
 
     /// Construct an array at runtime.
@@ -1031,21 +1045,27 @@ const DeclGen = struct {
                         const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod)));
                         defer self.gpa.free(constituents);
 
+                        const child_repr: Repr = switch (tag) {
+                            .array_type => .indirect,
+                            .vector_type => .direct,
+                            else => unreachable,
+                        };
+
                         switch (aggregate.storage) {
                             .bytes => |bytes| {
                                 // TODO: This is really space inefficient, perhaps there is a better
                                 // way to do it?
                                 for (constituents, bytes.toSlice(constituents.len, ip)) |*constituent, byte| {
-                                    constituent.* = try self.constInt(elem_ty, byte, .indirect);
+                                    constituent.* = try self.constInt(elem_ty, byte, child_repr);
                                 }
                             },
                             .elems => |elems| {
                                 for (constituents, elems) |*constituent, elem| {
-                                    constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), .indirect);
+                                    constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), child_repr);
                                 }
                             },
                             .repeated_elem => |elem| {
-                                @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), .indirect));
+                                @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), child_repr));
                             },
                         }
 
@@ -1334,7 +1354,11 @@ const DeclGen = struct {
     }
 
     fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !IdRef {
-        const key = .{ child_ty.toIntern(), storage_class };
+        return try self.ptrType2(child_ty, storage_class, .indirect);
+    }
+
+    fn ptrType2(self: *DeclGen, child_ty: Type, storage_class: StorageClass, child_repr: Repr) !IdRef {
+        const key = .{ child_ty.toIntern(), storage_class, child_repr };
         const entry = try self.ptr_types.getOrPut(self.gpa, key);
         if (entry.found_existing) {
             const fwd_id = entry.value_ptr.ty_id;
@@ -1354,7 +1378,7 @@ const DeclGen = struct {
             .fwd_emitted = false,
         };
 
-        const child_ty_id = try self.resolveType(child_ty, .indirect);
+        const child_ty_id = try self.resolveType(child_ty, child_repr);
 
         try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
             .id_result = result_id,
@@ -1645,11 +1669,10 @@ const DeclGen = struct {
             },
             .Vector => {
                 const elem_ty = ty.childType(mod);
-                // TODO: Make `.direct`.
-                const elem_ty_id = try self.resolveType(elem_ty, .indirect);
+                const elem_ty_id = try self.resolveType(elem_ty, repr);
                 const len = ty.vectorLen(mod);
 
-                if (self.isVector(ty)) {
+                if (self.isSpvVector(ty)) {
                     return try self.spv.vectorType(len, elem_ty_id);
                 } else {
                     return try self.arrayType(len, elem_ty_id);
@@ -1948,7 +1971,7 @@ const DeclGen = struct {
             const mod = wip.dg.module;
             if (wip.is_array) {
                 assert(ty.isVector(mod));
-                return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
+                return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index));
             } else {
                 assert(index == 0);
                 return value;
@@ -1961,11 +1984,7 @@ const DeclGen = struct {
         /// Results is in `direct` representation.
         fn finalize(wip: *WipElementWise) !IdRef {
             if (wip.is_array) {
-                // Convert all the constituents to indirect, as required for the array.
-                for (wip.results) |*result| {
-                    result.* = try wip.dg.convertToIndirect(wip.ty, result.*);
-                }
-                return try wip.dg.constructArray(wip.result_ty, wip.results);
+                return try wip.dg.constructVector(wip.result_ty, wip.results);
             } else {
                 return wip.results[0];
             }
@@ -1982,7 +2001,7 @@ const DeclGen = struct {
     /// Create a new element-wise operation.
     fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise {
         const mod = self.module;
-        const is_array = result_ty.isVector(mod) and (!self.isVector(result_ty) or force_element_wise);
+        const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise);
         const num_results = if (is_array) result_ty.vectorLen(mod) else 1;
         const results = try self.gpa.alloc(IdRef, num_results);
         @memset(results, undefined);
@@ -2253,29 +2272,102 @@ const DeclGen = struct {
     /// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct).
     fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
         const mod = self.module;
-        return switch (ty.zigTypeTag(mod)) {
-            .Bool => blk: {
-                const result_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
-                    .id_result_type = try self.resolveType(Type.bool, .direct),
-                    .id_result = result_id,
-                    .operand_1 = operand_id,
-                    .operand_2 = try self.constBool(false, .indirect),
-                });
-                break :blk result_id;
+        const scalar_ty = ty.scalarType(mod);
+        const is_spv_vector = self.isSpvVector(ty);
+        switch (scalar_ty.zigTypeTag(mod)) {
+            .Bool => {
+                // TODO: We may want to use something like elementWise in this function.
+                // First we need to audit whether this would recursively call into itself.
+                if (!ty.isVector(mod) or is_spv_vector) {
+                    const result_id = self.spv.allocId();
+                    const scalar_false_id = try self.constBool(false, .indirect);
+                    const false_id = if (is_spv_vector) blk: {
+                        const index = try mod.intern_pool.get(mod.gpa, .{
+                            .vector_type = .{
+                                .len = ty.vectorLen(mod),
+                                .child = Type.u1.toIntern(),
+                            },
+                        });
+                        const vec_ty = Type.fromInterned(index);
+                        break :blk try self.constructVectorSplat(vec_ty, scalar_false_id);
+                    } else scalar_false_id;
+
+                    try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                        .id_result_type = try self.resolveType(ty, .direct),
+                        .id_result = result_id,
+                        .operand_1 = operand_id,
+                        .operand_2 = false_id,
+                    });
+                    return result_id;
+                }
+
+                const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+                for (constituents, 0..) |*id, i| {
+                    const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+                    id.* = try self.convertToDirect(scalar_ty, element);
+                }
+                return try self.constructVector(ty, constituents);
             },
-            else => operand_id,
-        };
+            else => return operand_id,
+        }
     }
 
     /// Convert representation from direct (in 'register) to direct (in memory)
     /// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect).
     fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
         const mod = self.module;
-        return switch (ty.zigTypeTag(mod)) {
-            .Bool => try self.intFromBool(Type.u1, operand_id),
-            else => operand_id,
-        };
+        const scalar_ty = ty.scalarType(mod);
+        const is_spv_vector = self.isSpvVector(ty);
+        switch (scalar_ty.zigTypeTag(mod)) {
+            .Bool => {
+                const result_ty = if (is_spv_vector) blk: {
+                    const index = try mod.intern_pool.get(mod.gpa, .{
+                        .vector_type = .{
+                            .len = ty.vectorLen(mod),
+                            .child = Type.u1.toIntern(),
+                        },
+                    });
+                    break :blk Type.fromInterned(index);
+                } else Type.u1;
+
+                if (!ty.isVector(mod) or is_spv_vector) {
+                    // TODO: We may want to use something like elementWise in this function.
+                    // First we need to audit whether this would recursively call into itself.
+                    // Also unify it with intFromBool
+
+                    const scalar_zero_id = try self.constInt(Type.u1, 0, .direct);
+                    const scalar_one_id = try self.constInt(Type.u1, 1, .direct);
+
+                    const zero_id = if (is_spv_vector)
+                        try self.constructVectorSplat(result_ty, scalar_zero_id)
+                    else
+                        scalar_zero_id;
+
+                    const one_id = if (is_spv_vector)
+                        try self.constructVectorSplat(result_ty, scalar_one_id)
+                    else
+                        scalar_one_id;
+
+                    const result_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                        .id_result_type = try self.resolveType(result_ty, .direct),
+                        .id_result = result_id,
+                        .condition = operand_id,
+                        .object_1 = one_id,
+                        .object_2 = zero_id,
+                    });
+                    return result_id;
+                }
+
+                const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+                for (constituents, 0..) |*id, i| {
+                    const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+                    id.* = try self.convertToIndirect(scalar_ty, element);
+                }
+                return try self.constructVector(result_ty, constituents);
+            },
+            else => return operand_id,
+        }
     }
 
     fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef {
@@ -2292,6 +2384,21 @@ const DeclGen = struct {
         return try self.convertToDirect(result_ty, result_id);
     }
 
+    fn extractVectorComponent(self: *DeclGen, result_ty: Type, vector_id: IdRef, field: u32) !IdRef {
+        // Whether this is an OpTypeVector or OpTypeArray, we need to emit the same instruction regardless.
+        const result_ty_id = try self.resolveType(result_ty, .direct);
+        const result_id = self.spv.allocId();
+        const indexes = [_]u32{field};
+        try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
+            .id_result_type = result_ty_id,
+            .id_result = result_id,
+            .composite = vector_id,
+            .indexes = &indexes,
+        });
+        // Vector components are already stored in direct representation.
+        return result_id;
+    }
+
     const MemoryOptions = struct {
         is_volatile: bool = false,
     };
@@ -2926,7 +3033,7 @@ const DeclGen = struct {
         const ov_ty = result_ty.structFieldType(1, self.module);
 
         const bool_ty_id = try self.resolveType(Type.bool, .direct);
-        const cmp_ty_id = if (self.isVector(operand_ty))
+        const cmp_ty_id = if (self.isSpvVector(operand_ty))
             // TODO: Resolving a vector type with .direct should return a SPIR-V vector
             try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct))
         else
@@ -3100,7 +3207,7 @@ const DeclGen = struct {
         const ov_ty = result_ty.structFieldType(1, self.module);
 
         const bool_ty_id = try self.resolveType(Type.bool, .direct);
-        const cmp_ty_id = if (self.isVector(operand_ty))
+        const cmp_ty_id = if (self.isSpvVector(operand_ty))
             // TODO: Resolving a vector type with .direct should return a SPIR-V vector
             try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct))
         else
@@ -3312,7 +3419,7 @@ const DeclGen = struct {
 
         const info = self.arithmeticTypeInfo(operand_ty);
 
-        var result_id = try self.extractField(scalar_ty, operand, 0);
+        var result_id = try self.extractVectorComponent(scalar_ty, operand, 0);
         const len = operand_ty.vectorLen(mod);
 
         switch (reduce.operation) {
@@ -3320,7 +3427,7 @@ const DeclGen = struct {
                 const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
                 for (1..len) |i| {
                     const lhs = result_id;
-                    const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+                    const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i));
                     result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
                 }
 
@@ -3354,7 +3461,7 @@ const DeclGen = struct {
 
         for (1..len) |i| {
             const lhs = result_id;
-            const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+            const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i));
             result_id = self.spv.allocId();
 
             try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
@@ -3388,9 +3495,9 @@ const DeclGen = struct {
 
             const index = elem.toSignedInt(mod);
             if (index >= 0) {
-                result_id.* = try self.extractField(wip.ty, a, @intCast(index));
+                result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
             } else {
-                result_id.* = try self.extractField(wip.ty, b, @intCast(~index));
+                result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
             }
         }
         return try wip.finalize();
@@ -4086,8 +4193,7 @@ const DeclGen = struct {
                 defer self.gpa.free(elem_ids);
 
                 for (elements, 0..) |element, i| {
-                    const id = try self.resolve(element);
-                    elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
+                    elem_ids[i] = try self.resolve(element);
                 }
 
                 return try self.constructVector(result_ty, elem_ids);
@@ -4234,16 +4340,54 @@ const DeclGen = struct {
         const array_id = try self.resolve(bin_op.lhs);
         const index_id = try self.resolve(bin_op.rhs);
 
+        if (self.isSpvVector(array_ty)) {
+            const result_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpVectorExtractDynamic, .{
+                .id_result_type = try self.resolveType(elem_ty, .direct),
+                .id_result = result_id,
+                .vector = array_id,
+                .index = index_id,
+            });
+            return result_id;
+        }
+
         // SPIR-V doesn't have an array indexing function for some damn reason.
         // For now, just generate a temporary and use that.
         // TODO: This backend probably also should use isByRef from llvm...
 
-        const elem_ptr_ty_id = try self.ptrType(elem_ty, .Function);
+        const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct);
+        const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct);
+
+        const tmp_id = self.spv.allocId();
+        try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+            .id_result_type = ptr_array_ty_id,
+            .id_result = tmp_id,
+            .storage_class = .Function,
+        });
+
+        try self.func.body.emit(self.spv.gpa, .OpStore, .{
+            .pointer = tmp_id,
+            .object = array_id,
+        });
+
+        const elem_ptr_id = try self.accessChainId(ptr_elem_ty_id, tmp_id, &.{index_id});
+
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+            .id_result_type = try self.resolveType(elem_ty, .direct),
+            .id_result = result_id,
+            .pointer = elem_ptr_id,
+        });
+
+        if (array_ty.isVector(mod)) {
+            // Result is already in direct representation
+            return result_id;
+        }
+
+        // This is an array type; the elements are stored in indirect representation.
+        // We have to convert the type to direct.
 
-        const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function });
-        try self.store(array_ty, tmp_id, array_id, .{});
-        const elem_ptr_id = try self.accessChainId(elem_ptr_ty_id, tmp_id, &.{index_id});
-        return try self.load(elem_ty, elem_ptr_id, .{});
+        return try self.convertToDirect(elem_ty, result_id);
     }
 
     fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {