Commit 9641d2ebdb

Robin Voetter <robin@voetter.nl>
2024-01-21 20:12:25
spirv: vectorize max, min
1 parent 9f0227a
Changed files (2)
src
codegen
test
src/codegen/spirv.zig
@@ -2391,45 +2391,51 @@ const DeclGen = struct {
     }
 
     fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
-        const result_ty_ref = try self.resolveType(result_ty, .direct);
         const info = self.arithmeticTypeInfo(result_ty);
 
-        // TODO: Use fmin for OpenCL
-        const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
-        const selection_id = switch (info.class) {
-            .float => blk: {
-                // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
-                // but we want it to pick lhs. Therefore we also have to check if
-                // rhs is nan. We don't need to care about the result when both
-                // are nan.
-                const rhs_is_nan_id = self.spv.allocId();
-                const bool_ty_ref = try self.resolveType(Type.bool, .direct);
-                try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = rhs_is_nan_id,
-                    .x = rhs_id,
-                });
-                const float_cmp_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = float_cmp_id,
-                    .operand_1 = cmp_id,
-                    .operand_2 = rhs_is_nan_id,
-                });
-                break :blk float_cmp_id;
-            },
-            else => cmp_id,
-        };
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
+        for (wip.results, 0..) |*result_id, i| {
+            const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
+
+            // TODO: Use fmin for OpenCL
+            const cmp_id = try self.cmp(op, Type.bool, wip.scalar_ty, lhs_elem_id, rhs_elem_id);
+            const selection_id = switch (info.class) {
+                .float => blk: {
+                    // cmp uses OpFOrd. When we have 0 [<>] nan this returns false,
+                    // but we want it to pick lhs. Therefore we also have to check if
+                    // rhs is nan. We don't need to care about the result when both
+                    // are nan.
+                    const rhs_is_nan_id = self.spv.allocId();
+                    const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+                    try self.func.body.emit(self.spv.gpa, .OpIsNan, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = rhs_is_nan_id,
+                        .x = rhs_elem_id,
+                    });
+                    const float_cmp_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpLogicalOr, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = float_cmp_id,
+                        .operand_1 = cmp_id,
+                        .operand_2 = rhs_is_nan_id,
+                    });
+                    break :blk float_cmp_id;
+                },
+                else => cmp_id,
+            };
 
-        const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpSelect, .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .condition = selection_id,
-            .object_1 = lhs_id,
-            .object_2 = rhs_id,
-        });
-        return result_id;
+            result_id.* = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = result_id.*,
+                .condition = selection_id,
+                .object_1 = lhs_elem_id,
+                .object_2 = rhs_elem_id,
+            });
+        }
+        return wip.finalize();
     }
 
     /// This function normalizes values to a canonical representation
@@ -3107,20 +3113,15 @@ const DeclGen = struct {
                 return result_id;
             },
             .Vector => {
-                const child_ty = ty.childType(mod);
-                const vector_len = ty.vectorLen(mod);
-
-                const constituents = try self.gpa.alloc(IdRef, vector_len);
-                defer self.gpa.free(constituents);
-
-                for (constituents, 0..) |*constituent, i| {
-                    const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
-                    const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
-                    const result_id = try self.cmp(op, Type.bool, child_ty, lhs_index_id, rhs_index_id);
-                    constituent.* = try self.convertToIndirect(Type.bool, result_id);
+                var wip = try self.elementWise(result_ty);
+                defer wip.deinit();
+                const scalar_ty = ty.scalarType(mod);
+                for (wip.results, 0..) |*result_id, i| {
+                    const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+                    const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
+                    result_id.* = try self.cmp(op, Type.bool, scalar_ty, lhs_elem_id, rhs_elem_id);
                 }
-
-                return try self.constructArray(result_ty, constituents);
+                return wip.finalize();
             },
             else => unreachable,
         };
test/behavior/maximum_minimum.zig
@@ -31,7 +31,6 @@ test "@max on vectors" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;
 
@@ -86,7 +85,6 @@ test "@min for vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;
 
@@ -199,7 +197,6 @@ test "@min/@max notices vector bounds" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
 
     var x: @Vector(2, u16) = .{ 140, 40 };
@@ -253,7 +250,6 @@ test "@min/@max notices bounds from vector types" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
 
     var x: @Vector(2, u16) = .{ 30, 67 };
@@ -295,7 +291,6 @@ test "@min/@max notices bounds from vector types when element of comptime-known
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx)) return error.SkipZigTest;