Commit eb2d61d02e

Ali Chraghi <alichraghi@proton.me>
2024-02-06 09:38:25
spirv: merge `construct(Struct/Vector/Array)` into `constructComposite`
1 parent 42fcca4
Changed files (2)
src
codegen
link
src/codegen/spirv.zig
@@ -721,75 +721,18 @@ const DeclGen = struct {
         };
     }
 
-    /// Construct a struct at runtime.
-    /// ty must be a struct type.
-    /// Constituents should be in `indirect` representation (as the elements of a struct should be).
-    /// Result is in `direct` representation.
-    fn constructStruct(self: *DeclGen, ty: Type, types: []const Type, constituents: []const IdRef) !IdRef {
-        assert(types.len == constituents.len);
-        // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
-        // operands are not constant.
-        // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
-        // For now, just initialize the struct by setting the fields manually...
-        // TODO: Make this OpCompositeConstruct when we can
-        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
-        for (constituents, types, 0..) |constitent_id, member_ty, index| {
-            const ptr_member_ty_ref = try self.ptrType(member_ty, .Function);
-            const ptr_id = try self.accessChain(ptr_member_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = constitent_id,
-            });
-        }
-        return try self.load(ty, ptr_composite_id, .{});
-    }
-
-    /// Construct a vector at runtime.
-    /// ty must be an vector type.
-    /// Constituents should be in `indirect` representation (as the elements of an vector should be).
-    /// Result is in `direct` representation.
-    fn constructVector(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
-        // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
-        // operands are not constant.
-        // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
-        // For now, just initialize the struct by setting the fields manually...
-        // TODO: Make this OpCompositeConstruct when we can
-        const mod = self.module;
-        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
-        const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
-        for (constituents, 0..) |constitent_id, index| {
-            const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = constitent_id,
-            });
-        }
-
-        return try self.load(ty, ptr_composite_id, .{});
-    }
-
-    /// Construct an array at runtime.
-    /// ty must be an array type.
-    /// Constituents should be in `indirect` representation (as the elements of an array should be).
-    /// Result is in `direct` representation.
-    fn constructArray(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
-        // The Khronos LLVM-SPIRV translator crashes because it cannot construct structs which'
-        // operands are not constant.
-        // See https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1349
-        // For now, just initialize the struct by setting the fields manually...
-        // TODO: Make this OpCompositeConstruct when we can
-        const mod = self.module;
-        const ptr_composite_id = try self.alloc(ty, .{ .storage_class = .Function });
-        const ptr_elem_ty_ref = try self.ptrType(ty.elemType2(mod), .Function);
-        for (constituents, 0..) |constitent_id, index| {
-            const ptr_id = try self.accessChain(ptr_elem_ty_ref, ptr_composite_id, &.{@as(u32, @intCast(index))});
-            try self.func.body.emit(self.spv.gpa, .OpStore, .{
-                .pointer = ptr_id,
-                .object = constitent_id,
-            });
-        }
-
-        return try self.load(ty, ptr_composite_id, .{});
+    /// Construct a composite value at runtime. If the parameters are in direct
+    /// representation, then the result is also in direct representation. Otherwise,
+    /// if the parameters are in indirect representation, then the result is too.
+    fn constructComposite(self: *DeclGen, ty: Type, constituents: []const IdRef) !IdRef {
+        const constituents_id = self.spv.allocId();
+        const type_id = try self.resolveTypeId(ty);
+        try self.func.body.emit(self.spv.gpa, .OpCompositeConstruct, .{
+            .id_result_type = type_id,
+            .id_result = constituents_id,
+            .constituents = constituents,
+        });
+        return constituents_id;
     }
 
     /// This function generates a load for a constant in direct (ie, non-memory) representation.
@@ -897,18 +840,15 @@ const DeclGen = struct {
                 });
 
                 var constituents: [2]IdRef = undefined;
-                var types: [2]Type = undefined;
                 if (eu_layout.error_first) {
                     constituents[0] = try self.constant(err_ty, err_val, .indirect);
                     constituents[1] = try self.constant(payload_ty, payload_val, .indirect);
-                    types = .{ err_ty, payload_ty };
                 } else {
                     constituents[0] = try self.constant(payload_ty, payload_val, .indirect);
                     constituents[1] = try self.constant(err_ty, err_val, .indirect);
-                    types = .{ payload_ty, err_ty };
                 }
 
-                return try self.constructStruct(ty, &types, &constituents);
+                return try self.constructComposite(ty, &constituents);
             },
             .enum_tag => {
                 const int_val = try val.intFromEnum(ty, mod);
@@ -920,11 +860,7 @@ const DeclGen = struct {
                 const ptr_ty = ty.slicePtrFieldType(mod);
                 const ptr_id = try self.constantPtr(ptr_ty, Value.fromInterned(slice.ptr));
                 const len_id = try self.constant(Type.usize, Value.fromInterned(slice.len), .indirect);
-                return self.constructStruct(
-                    ty,
-                    &.{ ptr_ty, Type.usize },
-                    &.{ ptr_id, len_id },
-                );
+                return self.constructComposite(ty, &.{ ptr_id, len_id });
             },
             .opt => {
                 const payload_ty = ty.optionalChild(mod);
@@ -951,11 +887,7 @@ const DeclGen = struct {
                 else
                     try self.spv.constUndef(try self.resolveType(payload_ty, .indirect));
 
-                return try self.constructStruct(
-                    ty,
-                    &.{ payload_ty, Type.bool },
-                    &.{ payload_id, has_pl_id },
-                );
+                return try self.constructComposite(ty, &.{ payload_id, has_pl_id });
             },
             .aggregate => |aggregate| switch (ip.indexToKey(ty.ip_index)) {
                 inline .array_type, .vector_type => |array_type, tag| {
@@ -992,9 +924,9 @@ const DeclGen = struct {
                                 const sentinel = Value.fromInterned(array_type.sentinel);
                                 constituents[constituents.len - 1] = try self.constant(elem_ty, sentinel, .indirect);
                             }
-                            return self.constructArray(ty, constituents);
+                            return self.constructComposite(ty, constituents);
                         },
-                        inline .vector_type => return self.constructVector(ty, constituents),
+                        inline .vector_type => return self.constructComposite(ty, constituents),
                         else => unreachable,
                     }
                 },
@@ -1004,9 +936,6 @@ const DeclGen = struct {
                         return self.todo("packed struct constants", .{});
                     }
 
-                    var types = std.ArrayList(Type).init(self.gpa);
-                    defer types.deinit();
-
                     var constituents = std.ArrayList(IdRef).init(self.gpa);
                     defer constituents.deinit();
 
@@ -1022,11 +951,10 @@ const DeclGen = struct {
                         const field_val = try val.fieldValue(mod, field_index);
                         const field_id = try self.constant(field_ty, field_val, .indirect);
 
-                        try types.append(field_ty);
                         try constituents.append(field_id);
                     }
 
-                    return try self.constructStruct(ty, types.items, constituents.items);
+                    return try self.constructComposite(ty, constituents.items);
                 },
                 .anon_struct_type => unreachable, // TODO
                 else => unreachable,
@@ -1870,7 +1798,7 @@ const DeclGen = struct {
                 for (wip.results) |*result| {
                     result.* = try wip.dg.convertToIndirect(wip.scalar_ty, result.*);
                 }
-                return try wip.dg.constructArray(wip.result_ty, wip.results);
+                return try wip.dg.constructComposite(wip.result_ty, wip.results);
             } else {
                 return wip.results[0];
             }
@@ -2814,9 +2742,8 @@ const DeclGen = struct {
             ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
         }
 
-        return try self.constructStruct(
+        return try self.constructComposite(
             result_ty,
-            &.{ operand_ty, ov_ty },
             &.{ try wip_result.finalize(), try wip_ov.finalize() },
         );
     }
@@ -2905,9 +2832,8 @@ const DeclGen = struct {
             ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
         }
 
-        return try self.constructStruct(
+        return try self.constructComposite(
             result_ty,
-            &.{ operand_ty, ov_ty },
             &.{ try wip_result.finalize(), try wip_ov.finalize() },
         );
     }
@@ -3637,9 +3563,8 @@ const DeclGen = struct {
             // Convert the pointer-to-array to a pointer to the first element.
             try self.accessChain(elem_ptr_ty_ref, array_ptr_id, &.{0});
 
-        return try self.constructStruct(
+        return try self.constructComposite(
             slice_ty,
-            &.{ elem_ptr_ty, Type.usize },
             &.{ elem_ptr_id, len_id },
         );
     }
@@ -3651,14 +3576,12 @@ const DeclGen = struct {
         const bin_op = self.air.extraData(Air.Bin, ty_pl.payload).data;
         const ptr_id = try self.resolve(bin_op.lhs);
         const len_id = try self.resolve(bin_op.rhs);
-        const ptr_ty = self.typeOf(bin_op.lhs);
         const slice_ty = self.typeOfIndex(inst);
 
         // Note: Types should not need to be converted to direct, these types
         // dont need to be converted.
-        return try self.constructStruct(
+        return try self.constructComposite(
             slice_ty,
-            &.{ ptr_ty, Type.usize },
             &.{ ptr_id, len_id },
         );
     }
@@ -3680,8 +3603,6 @@ const DeclGen = struct {
                     unreachable; // TODO
                 }
 
-                const types = try self.gpa.alloc(Type, elements.len);
-                defer self.gpa.free(types);
                 const constituents = try self.gpa.alloc(IdRef, elements.len);
                 defer self.gpa.free(constituents);
                 var index: usize = 0;
@@ -3693,7 +3614,6 @@ const DeclGen = struct {
                             assert(Type.fromInterned(field_ty).hasRuntimeBits(mod));
 
                             const id = try self.resolve(element);
-                            types[index] = Type.fromInterned(field_ty);
                             constituents[index] = try self.convertToIndirect(Type.fromInterned(field_ty), id);
                             index += 1;
                         }
@@ -3707,7 +3627,6 @@ const DeclGen = struct {
                             assert(field_ty.hasRuntimeBitsIgnoreComptime(mod));
 
                             const id = try self.resolve(element);
-                            types[index] = field_ty;
                             constituents[index] = try self.convertToIndirect(field_ty, id);
                             index += 1;
                         }
@@ -3715,11 +3634,7 @@ const DeclGen = struct {
                     else => unreachable,
                 }
 
-                return try self.constructStruct(
-                    result_ty,
-                    types[0..index],
-                    constituents[0..index],
-                );
+                return try self.constructComposite(result_ty, constituents[0..index]);
             },
             .Vector => {
                 const n_elems = result_ty.vectorLen(mod);
@@ -3731,7 +3646,7 @@ const DeclGen = struct {
                     elem_ids[i] = try self.convertToIndirect(result_ty.childType(mod), id);
                 }
 
-                return try self.constructVector(result_ty, elem_ids);
+                return try self.constructComposite(result_ty, elem_ids);
             },
             .Array => {
                 const array_info = result_ty.arrayInfo(mod);
@@ -3748,7 +3663,7 @@ const DeclGen = struct {
                     elem_ids[n_elems - 1] = try self.constant(array_info.elem_type, sentinel_val, .indirect);
                 }
 
-                return try self.constructArray(result_ty, elem_ids);
+                return try self.constructComposite(result_ty, elem_ids);
             },
             else => unreachable,
         }
@@ -4885,11 +4800,7 @@ const DeclGen = struct {
         members[eu_layout.errorFieldIndex()] = operand_id;
         members[eu_layout.payloadFieldIndex()] = try self.spv.constUndef(payload_ty_ref);
 
-        var types: [2]Type = undefined;
-        types[eu_layout.errorFieldIndex()] = Type.anyerror;
-        types[eu_layout.payloadFieldIndex()] = payload_ty;
-
-        return try self.constructStruct(err_union_ty, &types, &members);
+        return try self.constructComposite(err_union_ty, &members);
     }
 
     fn airWrapErrUnionPayload(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
@@ -4910,11 +4821,7 @@ const DeclGen = struct {
         members[eu_layout.errorFieldIndex()] = try self.constInt(err_ty_ref, 0);
         members[eu_layout.payloadFieldIndex()] = try self.convertToIndirect(payload_ty, operand_id);
 
-        var types: [2]Type = undefined;
-        types[eu_layout.errorFieldIndex()] = Type.anyerror;
-        types[eu_layout.payloadFieldIndex()] = payload_ty;
-
-        return try self.constructStruct(err_union_ty, &types, &members);
+        return try self.constructComposite(err_union_ty, &members);
     }
 
     fn airIsNull(self: *DeclGen, inst: Air.Inst.Index, is_pointer: bool, pred: enum { is_null, is_non_null }) !?IdRef {
@@ -5091,8 +4998,7 @@ const DeclGen = struct {
 
         const payload_id = try self.convertToIndirect(payload_ty, operand_id);
         const members = [_]IdRef{ payload_id, try self.constBool(true, .indirect) };
-        const types = [_]Type{ payload_ty, Type.bool };
-        return try self.constructStruct(optional_ty, &types, &members);
+        return try self.constructComposite(optional_ty, &members);
     }
 
     fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
src/link/SpirV.zig
@@ -163,7 +163,7 @@ pub fn updateExports(
             .Vertex => spec.ExecutionModel.Vertex,
             .Fragment => spec.ExecutionModel.Fragment,
             .Kernel => spec.ExecutionModel.Kernel,
-            else => unreachable,
+            else => return,
         };
         const is_vulkan = target.os.tag == .vulkan;