Commit b41aad0193

Ali Chraghi <alichraghi@proton.me>
2024-02-01 17:08:23
spirv: emit vectors whenever we can
1 parent afa7793
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -508,6 +508,13 @@ pub fn intType(self: *Module, signedness: std.builtin.Signedness, bits: u16) !Ca
     } });
 }
 
+pub fn vectorType(self: *Module, len: u32, elem_ty_ref: CacheRef) !CacheRef {
+    return try self.resolve(.{ .vector_type = .{
+        .component_type = elem_ty_ref,
+        .component_count = len,
+    } });
+}
+
 pub fn arrayType(self: *Module, len: u32, elem_ty_ref: CacheRef) !CacheRef {
     const len_ty_ref = try self.resolve(.{ .int_type = .{
         .signedness = .unsigned,
src/codegen/spirv.zig
@@ -744,6 +744,30 @@ const DeclGen = struct {
         return try self.load(ty, ptr_composite_id, .{});
     }
 
+    /// Construct a vector at runtime.
+    /// ty must be an vector type.
+    /// Constituents should be in `indirect` representation (as the elements of an vector should be).
+    /// Result is in `direct` representation.
+    fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
+        // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
+        // operands are not constant.
+        // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
+        // For now, just initialize the struct by setting the fields manually...
+        // TODO: Make this OpCompositeConstruct when we can
+        const mod = self.module;
+        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
+        const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
+        for (constituents, 0..) |constitent_id, index| {
+            const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
+            try self.func.body.emit(self.spv.gpa, .OpStore, .{
+                .pointer = ptr_id,
+                .object = constitent_id,
+            });
+        }
+
+        return try self.load(ty, ptr_composite_id, .{});
+    }
+
     /// Construct an array at runtime.
     /// ty must be an array type.
     /// Constituents should be in `indirect` representation (as the elements of an array should be).
@@ -963,13 +987,16 @@ const DeclGen = struct {
                     }
 
                     switch (tag) {
-                        inline .array_type => if (array_type.sentinel != .none) {
-                            constituents[constituents.len - 1] = try self.constant(elem_ty, Value.fromInterned(array_type.sentinel), .indirect);
+                        inline .array_type => {
+                            if (array_type.sentinel != .none) {
+                                const sentinel = Value.fromInterned(array_type.sentinel);
+                                constituents[constituents.len - 1] = try self.constant(elem_ty, sentinel, .indirect);
+                            }
+                            return self.constructArray(ty, constituents);
                         },
-                        else => {},
+                        inline .vector_type => return self.constructVector(ty, constituents),
+                        else => unreachable,
                     }
-
-                    return try self.constructArray(ty, constituents);
                 },
                 .struct_type => {
                     const struct_type = mod.typeToStruct(ty).?;
@@ -1492,8 +1519,14 @@ const DeclGen = struct {
 
                 const elem_ty = ty.childType(mod);
                 const elem_ty_ref = try self.resolveType(elem_ty, .indirect);
+                const len = ty.vectorLen(mod);
+                const is_scalar = elem_ty.isNumeric(mod) or elem_ty.toIntern() == .bool_type;
+
+                const ty_ref = if (is_scalar and len > 1 and len <= 4)
+                    try self.spv.vectorType(ty.vectorLen(mod), elem_ty_ref)
+                else
+                    try self.spv.arrayType(ty.vectorLen(mod), elem_ty_ref);
 
-                const ty_ref = try self.spv.arrayType(ty.vectorLen(mod), elem_ty_ref);
                 try self.type_map.put(self.gpa, ty.toIntern(), .{ .ty_ref = ty_ref });
                 return ty_ref;
             },
@@ -3688,7 +3721,19 @@ const DeclGen = struct {
                     constituents[0..index],
                 );
             },
-            .Vector, .Array => {
+            .Vector => {
+                const n_elems = result_ty.vectorLen(mod);
+                const elem_ids = try self.gpa.alloc(IdRef, n_elems);
+                defer self.gpa.free(elem_ids);
+
+                for (elements, 0..) |element, i| {
+                    const id = try self.resolve(element);
+                    elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
+                }
+
+                return try self.constructVector(result_ty, elem_ids);
+            },
+            .Array => {
                 const array_info = result_ty.arrayInfo(mod);
                 const n_elems: usize = @intCast(result_ty.arrayLenIncludingSentinel(mod));
                 const elem_ids = try self.gpa.alloc(IdRef, n_elems);