Commit 44c31194e3

Ali Chraghi <alichraghi@proton.me>
2024-02-15 14:55:33
spirv: use extended instructions whenever possible
1 parent 6fe90a9
Changed files (3)
src
codegen
link
src/codegen/spirv/Module.zig
@@ -114,8 +114,10 @@ sections: struct {
     capabilities: Section = .{},
     /// OpExtension instructions
     extensions: Section = .{},
-    // OpExtInstImport instructions - skip for now.
-    // memory model defined by target, not required here.
+    /// OpExtInstImport
+    extended_instruction_set: Section = .{},
+    /// OpMemoryModel instruction; the memory model itself is defined by the target.
+    memory_model: Section = .{},
     /// OpEntryPoint instructions - Handled by `self.entry_points`.
     /// OpExecutionMode and OpExecutionModeId instructions.
     execution_modes: Section = .{},
@@ -172,6 +174,9 @@ globals: struct {
     section: Section = .{},
 } = .{},
 
+/// Maps every extended instruction set that has already been imported to the result-id of its OpExtInstImport.
+extended_instruction_set: std.AutoHashMapUnmanaged(ExtendedInstructionSet, IdRef) = .{},
+
 pub fn init(gpa: Allocator) Module {
     return .{
         .gpa = gpa,
@@ -182,6 +187,8 @@ pub fn init(gpa: Allocator) Module {
 pub fn deinit(self: *Module) void {
     self.sections.capabilities.deinit(self.gpa);
     self.sections.extensions.deinit(self.gpa);
+    self.sections.extended_instruction_set.deinit(self.gpa);
+    self.sections.memory_model.deinit(self.gpa);
     self.sections.execution_modes.deinit(self.gpa);
     self.sections.debug_strings.deinit(self.gpa);
     self.sections.debug_names.deinit(self.gpa);
@@ -200,6 +207,8 @@ pub fn deinit(self: *Module) void {
     self.globals.globals.deinit(self.gpa);
     self.globals.section.deinit(self.gpa);
 
+    self.extended_instruction_set.deinit(self.gpa);
+
     self.* = undefined;
 }
 
@@ -448,6 +457,8 @@ pub fn flush(self: *Module, file: std.fs.File, target: std.Target) !void {
         &header,
         self.sections.capabilities.toWords(),
         self.sections.extensions.toWords(),
+        self.sections.extended_instruction_set.toWords(),
+        self.sections.memory_model.toWords(),
         entry_points.toWords(),
         self.sections.execution_modes.toWords(),
         source.toWords(),
@@ -482,6 +493,29 @@ pub fn addFunction(self: *Module, decl_index: Decl.Index, func: Fn) !void {
     try self.declareDeclDeps(decl_index, func.decl_deps.keys());
 }
 
+/// The extended instruction sets that `importInstructionSet` knows how to import.
+pub const ExtendedInstructionSet = enum {
+    /// Imported under the name "GLSL.std.450".
+    glsl,
+    /// Imported under the name "OpenCL.std".
+    opencl,
+};
+
+/// Returns the result-id of the OpExtInstImport instruction for `set`.
+/// The import is emitted into `sections.extended_instruction_set` the first time
+/// a given set is requested; later calls return the cached id.
+/// Fails only if growing the cache or the section runs out of memory.
+pub fn importInstructionSet(self: *Module, set: ExtendedInstructionSet) !IdRef {
+    const gop = try self.extended_instruction_set.getOrPut(self.gpa, set);
+    if (gop.found_existing) return gop.value_ptr.*;
+    // The freshly inserted slot still holds an undefined id. Drop it again if
+    // emitting the import fails, so that a later call cannot take the
+    // `found_existing` path and hand out an uninitialized result-id.
+    errdefer _ = self.extended_instruction_set.remove(set);
+
+    const result_id = self.allocId();
+    try self.sections.extended_instruction_set.emit(self.gpa, .OpExtInstImport, .{
+        .id_result = result_id,
+        .name = switch (set) {
+            .glsl => "GLSL.std.450",
+            .opencl => "OpenCL.std",
+        },
+    });
+    gop.value_ptr.* = result_id;
+
+    return result_id;
+}
+
 /// Fetch the result-id of an OpString instruction that encodes the path of the source
 /// file of the decl. This function may also emit an OpSource with source-level information regarding
 /// the decl.
src/codegen/spirv.zig
@@ -632,11 +632,15 @@ const DeclGen = struct {
     /// Checks whether the type can be directly translated to SPIR-V vectors
     fn isVector(self: *DeclGen, ty: Type) bool {
         const mod = self.module;
+        const target = self.getTarget(); // consulted below for the OpenCL-only vector lengths
         if (ty.zigTypeTag(mod) != .Vector) return false;
         const elem_ty = ty.childType(mod);
+
         const len = ty.vectorLen(mod);
         const is_scalar = elem_ty.isNumeric(mod) or elem_ty.toIntern() == .bool_type;
-        return is_scalar and len > 1 and len <= 4;
+        const spirv_len = len > 1 and len <= 4; // lengths 2, 3 and 4 qualify on every target
+        const opencl_len = if (target.os.tag == .opencl) (len == 8 or len == 16) else false; // lengths 8 and 16 qualify only when the OS tag is OpenCL
+        return is_scalar and (spirv_len or opencl_len); // elements must be numeric or bool in either case
     }
 
     fn arithmeticTypeInfo(self: *DeclGen, ty: Type) ArithmeticTypeInfo {
@@ -1968,7 +1972,10 @@ const DeclGen = struct {
             try self.func.prologue.emit(self.spv.gpa, .OpFunction, .{
                 .id_result_type = self.typeId(return_ty_ref),
                 .id_result = decl_id,
-                .function_control = .{}, // TODO: We can set inline here if the type requires it.
+                .function_control = switch (fn_info.cc) {
+                    .Inline => .{ .Inline = true },
+                    else => .{},
+                },
                 .function_type = prototype_id,
             });
 
@@ -2437,48 +2444,71 @@ const DeclGen = struct {
 
     fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
+        // Emits the minimum (`op == .lt`) or maximum (`op == .gt`) of `lhs_id` and
+        // `rhs_id`, both of type `result_ty`, and returns the id of the result.
+        // Any other `op` is unreachable on the extended-instruction path.
+        //
+        // Floats on OpenCL and everything on Vulkan are lowered to one OpExtInst
+        // from the target's extended instruction set. Non-float operands on OpenCL
+        // use the compare-and-select fallback.
         const info = self.arithmeticTypeInfo(result_ty);
+        const target = self.getTarget();
 
-        var wip = try self.elementWise(result_ty, true);
+        const use_backup_codegen = target.os.tag == .opencl and info.class != .float;
+        var wip = try self.elementWise(result_ty, use_backup_codegen);
         defer wip.deinit();
+
         for (wip.results, 0..) |*result_id, i| {
             const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
             const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
 
-            // TODO: Use fmin for OpenCL
-            const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id);
-            const selection_id = switch (info.class) {
-                .float => blk: {
-                    // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
-                    // but we want it to pick lhs. Therefore we also have to check if
-                    // rhs is nan. We don't need to care about the result when both
-                    // are nan.
-                    const rhs_is_nan_id = self.spv.allocId();
-                    const bool_ty_ref = try self.resolveType(Type.bool, .direct);
-                    try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
-                        .id_result = rhs_is_nan_id,
-                        .x = rhs_elem_id,
-                    });
-                    const float_cmp_id = self.spv.allocId();
-                    try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
-                        .id_result_type = self.typeId(bool_ty_ref),
-                        .id_result = float_cmp_id,
-                        .operand_1 = cmp_id,
-                        .operand_2 = rhs_is_nan_id,
-                    });
-                    break :blk float_cmp_id;
-                },
-                else => cmp_id,
-            };
+            if (use_backup_codegen) {
+                // Non-float operands cannot be NaN, so a plain compare + select is enough.
+                const cmp_id = try self.cmp(op, Type.bool, wip.ty, lhs_elem_id, rhs_elem_id);
+                result_id.* = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                    .id_result_type = wip.ty_id,
+                    .id_result = result_id.*,
+                    .condition = cmp_id,
+                    .object_1 = lhs_elem_id,
+                    .object_2 = rhs_elem_id,
+                });
+            } else {
+                // Opcode number inside the extended instruction set imported below.
+                const ext_inst: Word = switch (target.os.tag) {
+                    .opencl => switch (op) {
+                        .lt => 28, // fmin
+                        .gt => 27, // fmax
+                        else => unreachable,
+                    },
+                    .vulkan => switch (info.class) {
+                        // NMin/NMax instead of FMin/FMax: GLSL.std.450 leaves the result
+                        // of FMin/FMax undefined when an operand is NaN, while NMin/NMax
+                        // return the other operand, which is what the previous
+                        // compare-and-select lowering guaranteed.
+                        .float => switch (op) {
+                            .lt => 79, // NMin
+                            .gt => 80, // NMax
+                            else => unreachable,
+                        },
+                        .integer, .strange_integer => switch (info.signedness) {
+                            .signed => switch (op) {
+                                .lt => 39, // SMin
+                                .gt => 42, // SMax
+                                else => unreachable,
+                            },
+                            .unsigned => switch (op) {
+                                .lt => 38, // UMin
+                                .gt => 41, // UMax
+                                else => unreachable,
+                            },
+                        },
+                        .composite_integer => unreachable, // TODO
+                        .bool => unreachable,
+                    },
+                    else => unreachable,
+                };
+                const set_id = switch (target.os.tag) {
+                    .opencl => try self.spv.importInstructionSet(.opencl),
+                    .vulkan => try self.spv.importInstructionSet(.glsl),
+                    else => unreachable,
+                };
 
-            result_id.* = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                .id_result_type = wip.ty_id,
-                .id_result = result_id.*,
-                .condition = selection_id,
-                .object_1 = lhs_elem_id,
-                .object_2 = rhs_elem_id,
-            });
+                result_id.* = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
+                    .id_result_type = wip.ty_id,
+                    .id_result = result_id.*,
+                    .set = set_id,
+                    .instruction = .{ .inst = ext_inst },
+                    .id_ref_4 = &.{ lhs_elem_id, rhs_elem_id },
+                });
+            }
         }
         return wip.finalize();
     }
@@ -2607,57 +2637,52 @@ const DeclGen = struct {
     }
 
     fn airAbs(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
-        const mod = self.module;
+        const target = self.getTarget(); // the OS tag picks the extended instruction set and opcode below
         const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
         const operand_id = try self.resolve(ty_op.operand);
         // Note: operand_ty may be signed, while ty is always unsigned!
         const operand_ty = self.typeOf(ty_op.operand);
         const result_ty = self.typeOfIndex(inst);
-        const info = self.arithmeticTypeInfo(result_ty);
-        const operand_scalar_ty = operand_ty.scalarType(mod);
-        const operand_scalar_ty_ref = try self.resolveType(operand_scalar_ty, .direct);
+        const operand_info = self.arithmeticTypeInfo(operand_ty); // opcode choice is keyed on the operand type, not the result type
 
-        var wip = try self.elementWise(result_ty, true);
+        var wip = try self.elementWise(result_ty, false); // NOTE(review): this flag was `true` before; presumably it forced per-element codegen - confirm
         defer wip.deinit();
 
-        const zero_id = switch (info.class) {
-            .float => try self.constFloat(operand_scalar_ty_ref, 0),
-            .integer, .strange_integer => try self.constInt(operand_scalar_ty_ref, 0),
-            .composite_integer => unreachable, // TODO
-            .bool => unreachable,
-        };
         for (wip.results, 0..) |*result_id, i| {
             const elem_id = try wip.elementAt(operand_ty, operand_id, i);
-            // Idk why spir-v doesn't have a dedicated abs() instruction in the base
-            // instruction set. For now we're just going to negate and check to avoid
-            // importing the extinst.
-            // TODO: Make this a call to compiler rt / ext inst
-            const neg_id = self.spv.allocId();
-            const args = .{
-                .id_result_type = self.typeId(operand_scalar_ty_ref),
-                .id_result = neg_id,
-                .operand_1 = zero_id,
-                .operand_2 = elem_id,
+            // Each result is produced by one OpExtInst; `ext_inst` is the opcode number inside the set imported as `set_id`.
+            const ext_inst: Word = switch (target.os.tag) {
+                .opencl => switch (operand_info.class) {
+                    .float => 23, // fabs
+                    .integer, .strange_integer => switch (operand_info.signedness) {
+                        .signed => 141, // s_abs
+                        .unsigned => 201, // u_abs
+                    },
+                    .composite_integer => unreachable, // TODO
+                    .bool => unreachable,
+                },
+                .vulkan => switch (operand_info.class) {
+                    .float => 4, // FAbs
+                    .integer, .strange_integer => 5, // SAbs (chosen for either signedness)
+                    .composite_integer => unreachable, // TODO
+                    .bool => unreachable,
+                },
+                else => unreachable,
             };
-            switch (info.class) {
-                .float => try self.func.body.emit(self.spv.gpa, .OpFSub, args),
-                .integer, .strange_integer => try self.func.body.emit(self.spv.gpa, .OpISub, args),
-                .composite_integer => unreachable, // TODO
-                .bool => unreachable,
-            }
-            const neg_norm_id = try self.normalize(wip.ty_ref, neg_id, info);
-
-            const gt_zero_id = try self.cmp(.gt, Type.bool, operand_scalar_ty, elem_id, zero_id);
-            const abs_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-                .id_result_type = self.typeId(operand_scalar_ty_ref),
-                .id_result = abs_id,
-                .condition = gt_zero_id,
-                .object_1 = elem_id,
-                .object_2 = neg_norm_id,
+            const set_id = switch (target.os.tag) { // importInstructionSet returns the existing id when the set was imported before
+                .opencl => try self.spv.importInstructionSet(.opencl),
+                .vulkan => try self.spv.importInstructionSet(.glsl),
+                else => unreachable,
+            };
+
+            result_id.* = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpExtInst, .{
+                .id_result_type = wip.ty_id,
+                .id_result = result_id.*,
+                .set = set_id,
+                .instruction = .{ .inst = ext_inst },
+                .id_ref_4 = &.{elem_id}, // the element is the only operand handed to the extended instruction
             });
-            // For Shader, we may need to cast from signed to unsigned here.
-            result_id.* = try self.bitCast(wip.ty, operand_scalar_ty, abs_id);
         }
         return try wip.finalize();
     }
src/link/SpirV.zig
@@ -246,7 +246,7 @@ fn writeCapabilities(spv: *SpvModule, target: std.Target) !void {
     const gpa = spv.gpa;
     // TODO: Integrate with a hypothetical feature system
     const caps: []const spec.Capability = switch (target.os.tag) {
-        .opencl => &.{ .Kernel, .Addresses, .Int8, .Int16, .Int64, .Float64, .Float16, .GenericPointer },
+        .opencl => &.{ .Kernel, .Addresses, .Int8, .Int16, .Int64, .Float64, .Float16, .Vector16, .GenericPointer },
         .glsl450 => &.{.Shader},
         .vulkan => &.{ .Shader, .VariablePointersStorageBuffer, .Int8, .Int16, .Int64, .Float64, .Float16 },
         else => unreachable, // TODO
@@ -279,8 +279,7 @@ fn writeMemoryModel(spv: *SpvModule, target: std.Target) !void {
         else => unreachable,
     };
 
-    // TODO: Put this in a proper section.
-    try spv.sections.extensions.emit(gpa, .OpMemoryModel, .{
+    try spv.sections.memory_model.emit(gpa, .OpMemoryModel, .{
         .addressing_model = addressing_model,
         .memory_model = memory_model,
     });