Commit 4bd9d9b7e0
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -22,6 +22,7 @@ const IdResultType = spec.IdResultType;
const StorageClass = spec.StorageClass;
const SpvModule = @import("spirv/Module.zig");
+const IdRange = SpvModule.IdRange;
const SpvSection = @import("spirv/Section.zig");
const SpvAssembler = @import("spirv/Assembler.zig");
@@ -32,7 +33,7 @@ pub const zig_call_abi_ver = 3;
const InternMap = std.AutoHashMapUnmanaged(struct { InternPool.Index, DeclGen.Repr }, IdResult);
const PtrTypeMap = std.AutoHashMapUnmanaged(
- struct { InternPool.Index, StorageClass },
+ struct { InternPool.Index, StorageClass, DeclGen.Repr },
struct { ty_id: IdRef, fwd_emitted: bool },
);
@@ -626,7 +627,7 @@ const DeclGen = struct {
}
/// Checks whether the type can be directly translated to SPIR-V vectors
- fn isVector(self: *DeclGen, ty: Type) bool {
+ fn isSpvVector(self: *DeclGen, ty: Type) bool {
const mod = self.module;
const target = self.getTarget();
if (ty.zigTypeTag(mod) != .Vector) return false;
@@ -798,26 +799,39 @@ const DeclGen = struct {
/// Construct a vector at runtime.
/// ty must be a vector type.
- /// Constituents should be in `indirect` representation (as the elements of an vector should be).
- /// Result is in `direct` representation.
+ /// `constituents` must hold exactly one id per vector lane; this is asserted.
+ /// The result type is resolved in `direct` representation.
fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
- // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
- // operands are not constant.
+ const mod = self.module;
+ assert(ty.vectorLen(mod) == constituents.len);
+
+ // Note: older versions of the Khronos SPIRV-LLVM translator crash on this instruction
+ // because they cannot construct structs whose operands are not constant.
// See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
- // For now, just initialize the struct by setting the fields manually...
- // TODO: Make this OpCompositeConstruct when we can
+ // Currently this is the case for Intel OpenCL CPU runtime (2023-WW46), but the
+ // alternatives don't work properly:
+ // - using temporaries/pointers doesn't work properly with vectors of bool, causes
+ // backends that use llvm to crash
+ // - using OpVectorInsertDynamic doesn't work for non-spirv-vectors of bool.
+
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
+ .id_result_type = try self.resolveType(ty, .direct),
+ .id_result = result_id,
+ .constituents = constituents,
+ });
+ return result_id;
+ }
+
+ /// Construct a vector at runtime with all lanes set to the same value.
+ /// ty must be a vector type.
+ /// The scratch list of constituents is freed again before this function returns.
+ fn constructVectorSplat(self: *DeclGen, ty: Type, constituent: IdRef) !IdRef {
const mod = self.module;
- const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
- const ptr_elem_ty_id = try self.ptrType(ty.elemType2(mod), .Function);
- for (constituents, 0..) |constitent_id, index| {
- const ptr_id = try self.accessChain(ptr_elem_ty_id, ptr_composite_id, &.{@as(u32, @intCast(index))});
- try self.func.body.emit(self.spv.gpa, .OpStore, .{
- .pointer = ptr_id,
- .object = constitent_id,
- });
- }
+ const n = ty.vectorLen(mod);
- return try self.load(ty, ptr_composite_id, .{});
+ // One copy of `constituent` per lane, handed to constructVector below.
+ const constituents = try self.gpa.alloc(IdRef, n);
+ defer self.gpa.free(constituents);
+ @memset(constituents, constituent);
+
+ return try self.constructVector(ty, constituents);
}
/// Construct an array at runtime.
@@ -1031,21 +1045,27 @@ const DeclGen = struct {
const constituents = try self.gpa.alloc(IdRef, @intCast(ty.arrayLenIncludingSentinel(mod)));
defer self.gpa.free(constituents);
+ const child_repr: Repr = switch (tag) {
+ .array_type => .indirect,
+ .vector_type => .direct,
+ else => unreachable,
+ };
+
switch (aggregate.storage) {
.bytes => |bytes| {
// TODO: This is really space inefficient, perhaps there is a better
// way to do it?
for (constituents, bytes.toSlice(constituents.len, ip)) |*constituent, byte| {
- constituent.* = try self.constInt(elem_ty, byte, .indirect);
+ constituent.* = try self.constInt(elem_ty, byte, child_repr);
}
},
.elems => |elems| {
for (constituents, elems) |*constituent, elem| {
- constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), .indirect);
+ constituent.* = try self.constant(elem_ty, Value.fromInterned(elem), child_repr);
}
},
.repeated_elem => |elem| {
- @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), .indirect));
+ @memset(constituents, try self.constant(elem_ty, Value.fromInterned(elem), child_repr));
},
}
@@ -1334,7 +1354,11 @@ const DeclGen = struct {
}
+ /// Shorthand for `ptrType2` with the child type in `indirect` representation.
fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !IdRef {
- const key = .{ child_ty.toIntern(), storage_class };
+ return try self.ptrType2(child_ty, storage_class, .indirect);
+ }
+
+ fn ptrType2(self: *DeclGen, child_ty: Type, storage_class: StorageClass, child_repr: Repr) !IdRef {
+ const key = .{ child_ty.toIntern(), storage_class, child_repr };
const entry = try self.ptr_types.getOrPut(self.gpa, key);
if (entry.found_existing) {
const fwd_id = entry.value_ptr.ty_id;
@@ -1354,7 +1378,7 @@ const DeclGen = struct {
.fwd_emitted = false,
};
- const child_ty_id = try self.resolveType(child_ty, .indirect);
+ const child_ty_id = try self.resolveType(child_ty, child_repr);
try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpTypePointer, .{
.id_result = result_id,
@@ -1645,11 +1669,10 @@ const DeclGen = struct {
},
.Vector => {
const elem_ty = ty.childType(mod);
- // TODO: Make `.direct`.
- const elem_ty_id = try self.resolveType(elem_ty, .indirect);
+ const elem_ty_id = try self.resolveType(elem_ty, repr);
const len = ty.vectorLen(mod);
- if (self.isVector(ty)) {
+ if (self.isSpvVector(ty)) {
return try self.spv.vectorType(len, elem_ty_id);
} else {
return try self.arrayType(len, elem_ty_id);
@@ -1948,7 +1971,7 @@ const DeclGen = struct {
const mod = wip.dg.module;
if (wip.is_array) {
assert(ty.isVector(mod));
- return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
+ return try wip.dg.extractVectorComponent(ty.childType(mod), value, @intCast(index));
} else {
assert(index == 0);
return value;
@@ -1961,11 +1984,7 @@ const DeclGen = struct {
/// Results is in `direct` representation.
fn finalize(wip: *WipElementWise) !IdRef {
if (wip.is_array) {
- // Convert all the constituents to indirect, as required for the array.
- for (wip.results) |*result| {
- result.* = try wip.dg.convertToIndirect(wip.ty, result.*);
- }
- return try wip.dg.constructArray(wip.result_ty, wip.results);
+ return try wip.dg.constructVector(wip.result_ty, wip.results);
} else {
return wip.results[0];
}
@@ -1982,7 +2001,7 @@ const DeclGen = struct {
/// Create a new element-wise operation.
fn elementWise(self: *DeclGen, result_ty: Type, force_element_wise: bool) !WipElementWise {
const mod = self.module;
- const is_array = result_ty.isVector(mod) and (!self.isVector(result_ty) or force_element_wise);
+ const is_array = result_ty.isVector(mod) and (!self.isSpvVector(result_ty) or force_element_wise);
const num_results = if (is_array) result_ty.vectorLen(mod) else 1;
const results = try self.gpa.alloc(IdRef, num_results);
@memset(results, undefined);
@@ -2253,29 +2272,102 @@ const DeclGen = struct {
/// This converts the argument type from resolveType(ty, .indirect) to resolveType(ty, .direct).
fn convertToDirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
const mod = self.module;
- return switch (ty.zigTypeTag(mod)) {
- .Bool => blk: {
- const result_id = self.spv.allocId();
- try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
- .id_result_type = try self.resolveType(Type.bool, .direct),
- .id_result = result_id,
- .operand_1 = operand_id,
- .operand_2 = try self.constBool(false, .indirect),
- });
- break :blk result_id;
+ // Only bools (scalar or as vector elements) differ between the two representations;
+ // every other scalar type is passed through unchanged.
+ const scalar_ty = ty.scalarType(mod);
+ const is_spv_vector = self.isSpvVector(ty);
+ switch (scalar_ty.zigTypeTag(mod)) {
+ .Bool => {
+ // TODO: We may want to use something like elementWise in this function.
+ // First we need to audit whether this would recursively call into itself.
+ if (!ty.isVector(mod) or is_spv_vector) {
+ // Scalars and SPIR-V vectors are converted with a single OpINotEqual
+ // against an all-false operand of matching shape.
+ const result_id = self.spv.allocId();
+ const scalar_false_id = try self.constBool(false, .indirect);
+ const false_id = if (is_spv_vector) blk: {
+ const index = try mod.intern_pool.get(mod.gpa, .{
+ .vector_type = .{
+ .len = ty.vectorLen(mod),
+ .child = Type.u1.toIntern(),
+ },
+ });
+ const vec_ty = Type.fromInterned(index);
+ break :blk try self.constructVectorSplat(vec_ty, scalar_false_id);
+ } else scalar_false_id;
+
+ try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+ .id_result_type = try self.resolveType(ty, .direct),
+ .id_result = result_id,
+ .operand_1 = operand_id,
+ .operand_2 = false_id,
+ });
+ return result_id;
+ }
+
+ // Vector of bool that is not a SPIR-V vector: convert it lane by lane.
+ const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+ // The list is only scratch space for constructVector (constructVectorSplat
+ // frees its list the same way), so release it on every exit path.
+ defer self.gpa.free(constituents);
+ for (constituents, 0..) |*id, i| {
+ const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+ id.* = try self.convertToDirect(scalar_ty, element);
+ }
+ return try self.constructVector(ty, constituents);
},
- else => operand_id,
- };
+ else => return operand_id,
+ }
}
/// Convert representation from direct (in 'register') to indirect (in memory)
/// This converts the argument type from resolveType(ty, .direct) to resolveType(ty, .indirect).
fn convertToIndirect(self: *DeclGen, ty: Type, operand_id: IdRef) !IdRef {
const mod = self.module;
- return switch (ty.zigTypeTag(mod)) {
- .Bool => try self.intFromBool(Type.u1, operand_id),
- else => operand_id,
- };
+ // Only bools (scalar or as vector elements) differ between the two representations;
+ // every other scalar type is passed through unchanged.
+ const scalar_ty = ty.scalarType(mod);
+ const is_spv_vector = self.isSpvVector(ty);
+ switch (scalar_ty.zigTypeTag(mod)) {
+ .Bool => {
+ // The result has the same shape as `ty` with every bool replaced by u1. This must
+ // also hold for vectors that are not SPIR-V vectors: the lane-by-lane path below
+ // passes `result_ty` to constructVector, which asserts on its vector length.
+ const result_ty = if (ty.isVector(mod)) blk: {
+ const index = try mod.intern_pool.get(mod.gpa, .{
+ .vector_type = .{
+ .len = ty.vectorLen(mod),
+ .child = Type.u1.toIntern(),
+ },
+ });
+ break :blk Type.fromInterned(index);
+ } else Type.u1;
+
+ if (!ty.isVector(mod) or is_spv_vector) {
+ // TODO: We may want to use something like elementWise in this function.
+ // First we need to audit whether this would recursively call into itself.
+ // Also unify it with intFromBool
+
+ const scalar_zero_id = try self.constInt(Type.u1, 0, .direct);
+ const scalar_one_id = try self.constInt(Type.u1, 1, .direct);
+
+ const zero_id = if (is_spv_vector)
+ try self.constructVectorSplat(result_ty, scalar_zero_id)
+ else
+ scalar_zero_id;
+
+ const one_id = if (is_spv_vector)
+ try self.constructVectorSplat(result_ty, scalar_one_id)
+ else
+ scalar_one_id;
+
+ // Select 1 where the bool operand is true and 0 where it is false.
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+ .id_result_type = try self.resolveType(result_ty, .direct),
+ .id_result = result_id,
+ .condition = operand_id,
+ .object_1 = one_id,
+ .object_2 = zero_id,
+ });
+ return result_id;
+ }
+
+ // Vector of bool that is not a SPIR-V vector: convert it lane by lane.
+ const constituents = try self.gpa.alloc(IdRef, ty.vectorLen(mod));
+ // The list is only scratch space for constructVector (constructVectorSplat
+ // frees its list the same way), so release it on every exit path.
+ defer self.gpa.free(constituents);
+ for (constituents, 0..) |*id, i| {
+ const element = try self.extractVectorComponent(scalar_ty, operand_id, @intCast(i));
+ id.* = try self.convertToIndirect(scalar_ty, element);
+ }
+ return try self.constructVector(result_ty, constituents);
+ },
+ else => return operand_id,
+ }
}
fn extractField(self: *DeclGen, result_ty: Type, object: IdRef, field: u32) !IdRef {
@@ -2292,6 +2384,21 @@ const DeclGen = struct {
return try self.convertToDirect(result_ty, result_id);
}
+ /// Extract the component at index `field` from `vector_id` as a value of `result_ty`
+ /// in `direct` representation.
+ fn extractVectorComponent(self: *DeclGen, result_ty: Type, vector_id: IdRef, field: u32) !IdRef {
+ // The same OpCompositeExtract is emitted whether the operand is an OpTypeVector
+ // or an OpTypeArray, so the two cases need not be told apart here.
+ const component_ty_id = try self.resolveType(result_ty, .direct);
+ const component_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpCompositeExtract, .{
+ .id_result_type = component_ty_id,
+ .id_result = component_id,
+ .composite = vector_id,
+ .indexes = &[_]u32{field},
+ });
+ // Components of a vector are already kept in direct representation, so no
+ // conversion is applied to the extracted value.
+ return component_id;
+ }
+
const MemoryOptions = struct {
is_volatile: bool = false,
};
@@ -2926,7 +3033,7 @@ const DeclGen = struct {
const ov_ty = result_ty.structFieldType(1, self.module);
const bool_ty_id = try self.resolveType(Type.bool, .direct);
- const cmp_ty_id = if (self.isVector(operand_ty))
+ const cmp_ty_id = if (self.isSpvVector(operand_ty))
// TODO: Resolving a vector type with .direct should return a SPIR-V vector
try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct))
else
@@ -3100,7 +3207,7 @@ const DeclGen = struct {
const ov_ty = result_ty.structFieldType(1, self.module);
const bool_ty_id = try self.resolveType(Type.bool, .direct);
- const cmp_ty_id = if (self.isVector(operand_ty))
+ const cmp_ty_id = if (self.isSpvVector(operand_ty))
// TODO: Resolving a vector type with .direct should return a SPIR-V vector
try self.spv.vectorType(operand_ty.vectorLen(mod), try self.resolveType(Type.bool, .direct))
else
@@ -3312,7 +3419,7 @@ const DeclGen = struct {
const info = self.arithmeticTypeInfo(operand_ty);
- var result_id = try self.extractField(scalar_ty, operand, 0);
+ var result_id = try self.extractVectorComponent(scalar_ty, operand, 0);
const len = operand_ty.vectorLen(mod);
switch (reduce.operation) {
@@ -3320,7 +3427,7 @@ const DeclGen = struct {
const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
for (1..len) |i| {
const lhs = result_id;
- const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i));
result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
}
@@ -3354,7 +3461,7 @@ const DeclGen = struct {
for (1..len) |i| {
const lhs = result_id;
- const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+ const rhs = try self.extractVectorComponent(scalar_ty, operand, @intCast(i));
result_id = self.spv.allocId();
try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
@@ -3388,9 +3495,9 @@ const DeclGen = struct {
const index = elem.toSignedInt(mod);
if (index >= 0) {
- result_id.* = try self.extractField(wip.ty, a, @intCast(index));
+ result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
} else {
- result_id.* = try self.extractField(wip.ty, b, @intCast(~index));
+ result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
}
}
return try wip.finalize();
@@ -4086,8 +4193,7 @@ const DeclGen = struct {
defer self.gpa.free(elem_ids);
for (elements, 0..) |element, i| {
- const id = try self.resolve(element);
- elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
+ elem_ids[i] = try self.resolve(element);
}
return try self.constructVector(result_ty, elem_ids);
@@ -4234,16 +4340,54 @@ const DeclGen = struct {
const array_id = try self.resolve(bin_op.lhs);
const index_id = try self.resolve(bin_op.rhs);
+ if (self.isSpvVector(array_ty)) {
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpVectorExtractDynamic, .{
+ .id_result_type = try self.resolveType(elem_ty, .direct),
+ .id_result = result_id,
+ .vector = array_id,
+ .index = index_id,
+ });
+ return result_id;
+ }
+
// SPIR-V doesn't have an array indexing function for some damn reason.
// For now, just generate a temporary and use that.
// TODO: This backend probably also should use isByRef from llvm...
- const elem_ptr_ty_id = try self.ptrType(elem_ty, .Function);
+ const ptr_array_ty_id = try self.ptrType2(array_ty, .Function, .direct);
+ const ptr_elem_ty_id = try self.ptrType2(elem_ty, .Function, .direct);
+
+ const tmp_id = self.spv.allocId();
+ try self.func.prologue.emit(self.spv.gpa, .OpVariable, .{
+ .id_result_type = ptr_array_ty_id,
+ .id_result = tmp_id,
+ .storage_class = .Function,
+ });
+
+ try self.func.body.emit(self.spv.gpa, .OpStore, .{
+ .pointer = tmp_id,
+ .object = array_id,
+ });
+
+ const elem_ptr_id = try self.accessChainId(ptr_elem_ty_id, tmp_id, &.{index_id});
+
+ const result_id = self.spv.allocId();
+ try self.func.body.emit(self.spv.gpa, .OpLoad, .{
+ .id_result_type = try self.resolveType(elem_ty, .direct),
+ .id_result = result_id,
+ .pointer = elem_ptr_id,
+ });
+
+ if (array_ty.isVector(mod)) {
+ // Result is already in direct representation
+ return result_id;
+ }
+
+ // This is an array type; the elements are stored in indirect representation.
+ // We have to convert the type to direct.
- const tmp_id = try self.alloc(array_ty, .{ .storage_class = .Function });
- try self.store(array_ty, tmp_id, array_id, .{});
- const elem_ptr_id = try self.accessChainId(elem_ptr_ty_id, tmp_id, &.{index_id});
- return try self.load(elem_ty, elem_ptr_id, .{});
+ return try self.convertToDirect(elem_ty, result_id);
}
fn airPtrElemVal(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {