Commit 761594e226

Robin Voetter <robin@voetter.nl>
2024-01-19 01:12:56
spirv: reduce, reduce_optimized
1 parent 2f81585
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2187,6 +2187,7 @@ const DeclGen = struct {
             .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
             .shl_with_overflow => try self.airShlOverflow(inst),
 
+            .reduce, .reduce_optimized => try self.airReduce(inst),
             .shuffle => try self.airShuffle(inst),
 
             .ptr_add => try self.airPtrAdd(inst),
@@ -2388,9 +2389,14 @@ const DeclGen = struct {
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
         const result_ty = self.typeOfIndex(inst);
-        const result_ty_ref = try self.resolveType(result_ty, .direct);
 
+        return try self.minMax(result_ty, op, lhs_id, rhs_id);
+    }
+
+    fn minMax(self: *DeclGen, result_ty: Type, op: std.math.CompareOperator, lhs_id: IdRef, rhs_id: IdRef) !IdRef {
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
         const info = try self.arithmeticTypeInfo(result_ty);
+
         // TODO: Use fmin for OpenCL
         const cmp_id = try self.cmp(op, Type.bool, result_ty, lhs_id, rhs_id);
         const selection_id = switch (info.class) {
@@ -2758,6 +2764,73 @@ const DeclGen = struct {
         );
     }
 
+    fn airReduce(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const mod = self.module;
+        const reduce = self.air.instructions.items(.data)[@intFromEnum(inst)].reduce;
+        const operand = try self.resolve(reduce.operand);
+        const operand_ty = self.typeOf(reduce.operand);
+        const scalar_ty = operand_ty.scalarType(mod);
+        const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
+        const scalar_ty_id = self.typeId(scalar_ty_ref);
+
+        const info = try self.arithmeticTypeInfo(operand_ty);
+
+        var result_id = try self.extractField(scalar_ty, operand, 0);
+        const len = operand_ty.vectorLen(mod);
+
+        switch (reduce.operation) {
+            .Min, .Max => |op| {
+                const cmp_op: std.math.CompareOperator = if (op == .Max) .gt else .lt;
+                for (1..len) |i| {
+                    const lhs = result_id;
+                    const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+                    result_id = try self.minMax(scalar_ty, cmp_op, lhs, rhs);
+                }
+
+                return result_id;
+            },
+            else => {},
+        }
+
+        const opcode: Opcode = switch (info.class) {
+            .bool => switch (reduce.operation) {
+                .And => .OpLogicalAnd,
+                .Or => .OpLogicalOr,
+                .Xor => .OpLogicalNotEqual,
+                else => unreachable,
+            },
+            .strange_integer, .integer => switch (reduce.operation) {
+                .And => .OpBitwiseAnd,
+                .Or => .OpBitwiseOr,
+                .Xor => .OpBitwiseXor,
+                .Add => .OpIAdd,
+                .Mul => .OpIMul,
+                else => unreachable,
+            },
+            .float => switch (reduce.operation) {
+                .Add => .OpFAdd,
+                .Mul => .OpFMul,
+                else => unreachable,
+            },
+            .composite_integer => unreachable, // TODO
+        };
+
+        for (1..len) |i| {
+            const lhs = result_id;
+            const rhs = try self.extractField(scalar_ty, operand, @intCast(i));
+            result_id = self.spv.allocId();
+
+            try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
+            self.func.body.writeOperand(spec.IdResultType, scalar_ty_id);
+            self.func.body.writeOperand(spec.IdResult, result_id);
+            self.func.body.writeOperand(spec.IdResultType, lhs);
+            self.func.body.writeOperand(spec.IdResultType, rhs);
+        }
+
+        return result_id;
+    }
+
     fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const mod = self.module;
         if (self.liveness.isUnused(inst)) return null;
test/behavior/vector.zig
@@ -1231,7 +1231,6 @@ test "byte vector initialized in inline function" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     if (comptime builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .x86_64 and
         builtin.cpu.features.isEnabled(@intFromEnum(std.Target.x86.Feature.avx512f)))
@@ -1301,7 +1300,6 @@ test "@intCast to u0" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var zeros = @Vector(2, u32){ 0, 0 };
     _ = &zeros;