Commit 7dfd403da1

Robin Voetter <robin@voetter.nl>
2024-01-21 12:17:19
spirv: air mul_add
1 parent 345d6e2
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2160,9 +2160,9 @@ const DeclGen = struct {
         const air_tags = self.air.instructions.items(.tag);
         const maybe_result_id: ?IdRef = switch (air_tags[@intFromEnum(inst)]) {
             // zig fmt: off
-            .add, .add_wrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
-            .sub, .sub_wrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
-            .mul, .mul_wrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
+            .add, .add_wrap, .add_optimized => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
+            .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
+            .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
 
             .div_float,
             .div_float_optimized,
@@ -2179,6 +2179,8 @@ const DeclGen = struct {
             .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
             .shl_with_overflow => try self.airShlOverflow(inst),
 
+            .mul_add => try self.airMulAdd(inst),
+
             .reduce, .reduce_optimized => try self.airReduce(inst),
             .shuffle => try self.airShuffle(inst),
 
@@ -2439,40 +2441,38 @@ const DeclGen = struct {
         switch (info.class) {
             .integer, .bool, .float => return value_id,
             .composite_integer => unreachable, // TODO
-            .strange_integer => {
-                switch (info.signedness) {
-                    .unsigned => {
-                        const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
-                        const result_id = self.spv.allocId();
-                        const mask_id = try self.constInt(ty_ref, mask_value);
-                        try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
-                            .id_result_type = self.typeId(ty_ref),
-                            .id_result = result_id,
-                            .operand_1 = value_id,
-                            .operand_2 = mask_id,
-                        });
-                        return result_id;
-                    },
-                    .signed => {
-                        // Shift left and right so that we can copy the sight bit that way.
-                        const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
-                        const left_id = self.spv.allocId();
-                        try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
-                            .id_result_type = self.typeId(ty_ref),
-                            .id_result = left_id,
-                            .base = value_id,
-                            .shift = shift_amt_id,
-                        });
-                        const right_id = self.spv.allocId();
-                        try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
-                            .id_result_type = self.typeId(ty_ref),
-                            .id_result = right_id,
-                            .base = left_id,
-                            .shift = shift_amt_id,
-                        });
-                        return right_id;
-                    },
-                }
+            .strange_integer => switch (info.signedness) {
+                .unsigned => {
+                    const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
+                    const result_id = self.spv.allocId();
+                    const mask_id = try self.constInt(ty_ref, mask_value);
+                    try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = result_id,
+                        .operand_1 = value_id,
+                        .operand_2 = mask_id,
+                    });
+                    return result_id;
+                },
+                .signed => {
+                    // Shift left and then arithmetic-shift right so that the sign bit is replicated into the upper bits.
+                    const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
+                    const left_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = left_id,
+                        .base = value_id,
+                        .shift = shift_amt_id,
+                    });
+                    const right_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+                        .id_result_type = self.typeId(ty_ref),
+                        .id_result = right_id,
+                        .base = left_id,
+                        .shift = shift_amt_id,
+                    });
+                    return right_id;
+                },
             },
         }
     }
@@ -2761,6 +2761,42 @@ const DeclGen = struct {
         );
     }
 
+    fn airMulAdd(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op;
+        const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+
+        const mulend1 = try self.resolve(extra.lhs);
+        const mulend2 = try self.resolve(extra.rhs);
+        const addend = try self.resolve(pl_op.operand);
+
+        const ty = self.typeOfIndex(inst);
+
+        const info = self.arithmeticTypeInfo(ty);
+        assert(info.class == .float); // .mul_add is only emitted for floats
+
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (0..wip.results.len) |i| {
+            const mul_result = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpFMul, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = mul_result,
+                .operand_1 = try wip.elementAt(ty, mulend1, i),
+                .operand_2 = try wip.elementAt(ty, mulend2, i),
+            });
+
+            try self.func.body.emit(self.spv.gpa, .OpFAdd, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .operand_1 = mul_result,
+                .operand_2 = try wip.elementAt(ty, addend, i),
+            });
+        }
+        return try wip.finalize();
+    }
+
     fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         const mod = self.module;
test/behavior/muladd.zig
@@ -10,7 +10,6 @@ test "@mulAdd" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime testMulAdd();
     try testMulAdd();
@@ -37,7 +36,6 @@ test "@mulAdd f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
     try comptime testMulAdd16();
@@ -111,7 +109,6 @@ test "vector f16" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime vector16();
     try vector16();
@@ -136,7 +133,6 @@ test "vector f32" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime vector32();
     try vector32();
@@ -161,7 +157,6 @@ test "vector f64" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try comptime vector64();
     try vector64();