Commit 45a1945dc4

Ali Chraghi <alichraghi@proton.me>
2023-10-14 02:35:04
spirv: simple binary and comparison vector operations
1 parent 9c20449
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2027,12 +2027,13 @@ const DeclGen = struct {
             .struct_field_ptr_index_2 => try self.airStructFieldPtrIndex(inst, 2),
             .struct_field_ptr_index_3 => try self.airStructFieldPtrIndex(inst, 3),
 
-            .cmp_eq  => try self.airCmp(inst, .eq),
-            .cmp_neq => try self.airCmp(inst, .neq),
-            .cmp_gt  => try self.airCmp(inst, .gt),
-            .cmp_gte => try self.airCmp(inst, .gte),
-            .cmp_lt  => try self.airCmp(inst, .lt),
-            .cmp_lte => try self.airCmp(inst, .lte),
+            .cmp_eq     => try self.airCmp(inst, .eq),
+            .cmp_neq    => try self.airCmp(inst, .neq),
+            .cmp_gt     => try self.airCmp(inst, .gt),
+            .cmp_gte    => try self.airCmp(inst, .gte),
+            .cmp_lt     => try self.airCmp(inst, .lt),
+            .cmp_lte    => try self.airCmp(inst, .lte),
+            .cmp_vector => try self.airVectorCmp(inst),
 
             .arg     => self.airArg(),
             .alloc   => try self.airAlloc(inst),
@@ -2088,13 +2089,30 @@ const DeclGen = struct {
         try self.inst_results.putNoClobber(self.gpa, inst, result_id);
     }
 
-    fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
-        if (self.liveness.isUnused(inst)) return null;
-        const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-        const lhs_id = try self.resolve(bin_op.lhs);
-        const rhs_id = try self.resolve(bin_op.rhs);
+    fn binOpSimple(self: *DeclGen, ty: Type, lhs_id: IdRef, rhs_id: IdRef, comptime opcode: Opcode) !IdRef {
+        const mod = self.module;
+
+        if (ty.isVector(mod)) {
+            const child_ty = ty.childType(mod);
+            const vector_len = ty.vectorLen(mod);
+
+            var constituents = try self.gpa.alloc(IdRef, vector_len);
+            defer self.gpa.free(constituents);
+
+            for (constituents, 0..) |*constituent, i| {
+                const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
+                const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
+                const result_id = try self.binOpSimple(child_ty, lhs_index_id, rhs_index_id, opcode);
+                constituent.* = try self.convertToIndirect(child_ty, result_id);
+            }
+
+            const result_ty = try self.resolveType(child_ty, .indirect);
+            const result_ty_ref = try self.spv.arrayType(vector_len, result_ty);
+            return try self.constructArray(result_ty_ref, constituents);
+        }
+
         const result_id = self.spv.allocId();
-        const result_type_id = try self.resolveTypeId(self.typeOfIndex(inst));
+        const result_type_id = try self.resolveTypeId(ty);
         try self.func.body.emit(self.spv.gpa, opcode, .{
             .id_result_type = result_type_id,
             .id_result = result_id,
@@ -2104,6 +2122,17 @@ const DeclGen = struct {
         return result_id;
     }
 
+    fn airBinOpSimple(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const bin_op = self.air.instructions.items(.data)[inst].bin_op;
+        const lhs_id = try self.resolve(bin_op.lhs);
+        const rhs_id = try self.resolve(bin_op.rhs);
+        const ty = self.typeOf(bin_op.lhs);
+
+        return try self.binOpSimple(ty, lhs_id, rhs_id, opcode);
+    }
+
     fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
@@ -2676,6 +2705,24 @@ const DeclGen = struct {
                 }
                 return result_id;
             },
+            .Vector => {
+                const child_ty = ty.childType(mod);
+                const vector_len = ty.vectorLen(mod);
+                const bool_ty_ref_indirect = try self.resolveType(Type.bool, .indirect);
+
+                var constituents = try self.gpa.alloc(IdRef, vector_len);
+                defer self.gpa.free(constituents);
+
+                for (constituents, 0..) |*constituent, i| {
+                    const lhs_index_id = try self.extractField(child_ty, cmp_lhs_id, @intCast(i));
+                    const rhs_index_id = try self.extractField(child_ty, cmp_rhs_id, @intCast(i));
+                    const result_id = try self.cmp(op, child_ty, lhs_index_id, rhs_index_id);
+                    constituent.* = try self.convertToIndirect(Type.bool, result_id);
+                }
+
+                const result_ty_ref = try self.spv.arrayType(vector_len, bool_ty_ref_indirect);
+                return try self.constructArray(result_ty_ref, constituents);
+            },
             else => unreachable,
         };
 
@@ -2751,6 +2798,19 @@ const DeclGen = struct {
         return try self.cmp(op, ty, lhs_id, rhs_id);
     }
 
+    fn airVectorCmp(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+        const vec_cmp = self.air.extraData(Air.VectorCmp, ty_pl.payload).data;
+        const lhs_id = try self.resolve(vec_cmp.lhs);
+        const rhs_id = try self.resolve(vec_cmp.rhs);
+        const op = vec_cmp.compareOperator();
+        const ty = self.typeOf(vec_cmp.lhs);
+
+        return try self.cmp(op, ty, lhs_id, rhs_id);
+    }
+
     fn bitCast(
         self: *DeclGen,
         dst_ty: Type,
test/behavior/vector.zig
@@ -53,7 +53,6 @@ test "vector bin compares with mem.eql" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -123,7 +122,6 @@ test "vector bit operators" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -142,7 +140,6 @@ test "implicit cast vector to array" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -160,7 +157,6 @@ test "array to vector" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -285,7 +281,6 @@ test "vector casts of sizes not divisible by 8" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -366,7 +361,6 @@ test "load vector elements via comptime index" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -388,7 +382,6 @@ test "store vector elements via comptime index" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -416,7 +409,6 @@ test "load vector elements via runtime index" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -438,7 +430,6 @@ test "store vector elements via runtime index" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -461,7 +452,6 @@ test "initialize vector which is a struct field" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Vec4Obj = struct {
         data: @Vector(4, f32),
@@ -485,7 +475,6 @@ test "vector comparison operators" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -1167,7 +1156,6 @@ test "loading the second vector from a slice of vectors" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     @setRuntimeSafety(false);
     var small_bases = [2]@Vector(2, u8){
@@ -1184,7 +1172,6 @@ test "array of vectors is copied" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Vec3 = @Vector(3, i32);
     var points = [_]Vec3{
@@ -1255,7 +1242,6 @@ test "zero multiplicand" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const zeros = @Vector(2, u32){ 0.0, 0.0 };
     var ones = @Vector(2, u32){ 1.0, 1.0 };
@@ -1316,7 +1302,6 @@ test "load packed vector element" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
 
@@ -1347,7 +1332,6 @@ test "store to vector in slice" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var v = [_]@Vector(3, f32){
         .{ 1, 1, 1 },
@@ -1411,7 +1395,6 @@ test "store vector with memset" {
 test "addition of vectors represented as strings" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const V = @Vector(3, u8);
     const foo: V = "foo".*;
@@ -1437,7 +1420,6 @@ test "vector pointer is indexable" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
 
     const V = @Vector(2, u32);
 
@@ -1478,7 +1460,6 @@ test "bitcast to vector with different child type" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
 
     const S = struct {
         fn doTheTest() !void {