Commit 403c6262bb

Robin Voetter <robin@voetter.nl>
2024-01-15 23:38:43
spirv: use new vector stuff for arithOp and shift
1 parent cb9e20d
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -1782,6 +1782,19 @@ const DeclGen = struct {
             wip.dg.gpa.free(wip.results);
         }
 
+        /// Return the scalar type of an input vector. This type is expected to be a vector
+        /// if `wip.is_vector`, and a scalar otherwise.
+        fn scalarType(wip: WipElementWise, ty: Type) Type {
+            const mod = wip.dg.module;
+            if (wip.is_vector) {
+                assert(ty.isVector(mod));
+                return ty.childType(mod);
+            } else {
+                assert(!ty.isVector(mod));
+                return ty;
+            }
+        }
+
         /// Utility function to extract the element at a particular index in an
         /// input vector. This type is expected to be a vector if `wip.is_vector`, and
         /// a scalar otherwise.
@@ -1789,7 +1802,7 @@ const DeclGen = struct {
             const mod = wip.dg.module;
             if (wip.is_vector) {
                 assert(ty.isVector(mod));
-                return try wip.dg.extractField(ty, value, @intCast(index));
+                return try wip.dg.extractField(ty.childType(mod), value, @intCast(index));
             } else {
                 assert(!ty.isVector(mod));
                 assert(index == 0);
@@ -2331,36 +2344,45 @@ const DeclGen = struct {
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
         const result_ty = self.typeOfIndex(inst);
-        const result_ty_ref = try self.resolveType(result_ty, .direct);
-
-        const result_id = self.spv.allocId();
 
         // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
         // so just manually upcast it if required.
-        const shift_ty_ref = try self.resolveType(self.typeOf(bin_op.rhs), .direct);
-        const shift_id = if (shift_ty_ref != result_ty_ref) blk: {
-            const shift_id = self.spv.allocId();
-            try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-                .id_result_type = self.typeId(result_ty_ref),
-                .id_result = shift_id,
-                .unsigned_value = rhs_id,
-            });
-            break :blk shift_id;
-        } else rhs_id;
+        // TODO(robin)
 
-        const args = .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .base = lhs_id,
-            .shift = shift_id,
-        };
+        var wip = try self.elementWise(result_ty);
+        defer wip.deinit();
 
-        if (result_ty.isSignedInt(mod)) {
-            try self.func.body.emit(self.spv.gpa, signed, args);
-        } else {
-            try self.func.body.emit(self.spv.gpa, unsigned, args);
+        const shift_ty = wip.scalarType(self.typeOf(bin_op.rhs));
+        const shift_ty_ref = try self.resolveType(shift_ty, .direct);
+
+        for (0..wip.results.len) |i| {
+            const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
+
+            const shift_id = if (shift_ty_ref != wip.result_ty_ref) blk: {
+                const shift_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                    .id_result_type = wip.scalar_ty_id,
+                    .id_result = shift_id,
+                    .unsigned_value = rhs_elem_id,
+                });
+                break :blk shift_id;
+            } else rhs_elem_id;
+
+            const args = .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .base = lhs_elem_id,
+                .shift = shift_id,
+            };
+
+            if (result_ty.isSignedInt(mod)) {
+                try self.func.body.emit(self.spv.gpa, signed, args);
+            } else {
+                try self.func.body.emit(self.spv.gpa, unsigned, args);
+            }
         }
-        return result_id;
+        return try wip.finalize();
     }
 
     fn airMinMax(self: *DeclGen, inst: Air.Inst.Index, op: std.math.CompareOperator) !?IdRef {
@@ -2483,35 +2505,14 @@ const DeclGen = struct {
     fn arithOp(
         self: *DeclGen,
         ty: Type,
-        lhs_id_: IdRef,
-        rhs_id_: IdRef,
+        lhs_id: IdRef,
+        rhs_id: IdRef,
         comptime fop: Opcode,
         comptime sop: Opcode,
         comptime uop: Opcode,
         /// true if this operation holds under modular arithmetic.
         comptime modular: bool,
     ) !IdRef {
-        var rhs_id = rhs_id_;
-        var lhs_id = lhs_id_;
-
-        const mod = self.module;
-        const result_ty_ref = try self.resolveType(ty, .direct);
-
-        if (ty.isVector(mod)) {
-            const child_ty = ty.childType(mod);
-            const vector_len = ty.vectorLen(mod);
-            const constituents = try self.gpa.alloc(IdRef, vector_len);
-            defer self.gpa.free(constituents);
-
-            for (constituents, 0..) |*constituent, i| {
-                const lhs_index_id = try self.extractField(child_ty, lhs_id, @intCast(i));
-                const rhs_index_id = try self.extractField(child_ty, rhs_id, @intCast(i));
-                constituent.* = try self.arithOp(child_ty, lhs_index_id, rhs_index_id, fop, sop, uop, modular);
-            }
-
-            return self.constructArray(ty, constituents);
-        }
-
         // Binary operations are generally applicable to both scalar and vector operations
         // in SPIR-V, but int and float versions of operations require different opcodes.
         const info = try self.arithmeticTypeInfo(ty);
@@ -2520,17 +2521,7 @@ const DeclGen = struct {
             .composite_integer => {
                 return self.todo("binary operations for composite integers", .{});
             },
-            .strange_integer => blk: {
-                if (!modular) {
-                    lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info);
-                    rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info);
-                }
-                break :blk switch (info.signedness) {
-                    .signed => @as(usize, 1),
-                    .unsigned => @as(usize, 2),
-                };
-            },
-            .integer => switch (info.signedness) {
+            .integer, .strange_integer => switch (info.signedness) {
                 .signed => @as(usize, 1),
                 .unsigned => @as(usize, 2),
             },
@@ -2538,24 +2529,41 @@ const DeclGen = struct {
             .bool => unreachable,
         };
 
-        const result_id = self.spv.allocId();
-        const operands = .{
-            .id_result_type = self.typeId(result_ty_ref),
-            .id_result = result_id,
-            .operand_1 = lhs_id,
-            .operand_2 = rhs_id,
-        };
+        var wip = try self.elementWise(ty);
+        defer wip.deinit();
+        for (0..wip.results.len) |i| {
+            const lhs_elem_id = try wip.elementAt(ty, lhs_id, i);
+            const rhs_elem_id = try wip.elementAt(ty, rhs_id, i);
 
-        switch (opcode_index) {
-            0 => try self.func.body.emit(self.spv.gpa, fop, operands),
-            1 => try self.func.body.emit(self.spv.gpa, sop, operands),
-            2 => try self.func.body.emit(self.spv.gpa, uop, operands),
-            else => unreachable,
+            const lhs_norm_id = if (modular and info.class == .strange_integer)
+                try self.normalizeInt(wip.scalar_ty_ref, lhs_elem_id, info)
+            else
+                lhs_elem_id;
+
+            const rhs_norm_id = if (modular and info.class == .strange_integer)
+                try self.normalizeInt(wip.scalar_ty_ref, rhs_elem_id, info)
+            else
+                rhs_elem_id;
+
+            const operands = .{
+                .id_result_type = wip.scalar_ty_id,
+                .id_result = wip.allocId(i),
+                .operand_1 = lhs_norm_id,
+                .operand_2 = rhs_norm_id,
+            };
+
+            switch (opcode_index) {
+                0 => try self.func.body.emit(self.spv.gpa, fop, operands),
+                1 => try self.func.body.emit(self.spv.gpa, sop, operands),
+                2 => try self.func.body.emit(self.spv.gpa, uop, operands),
+                else => unreachable,
+            }
+
+            // TODO: Trap on overflow? Probably going to be annoying.
+            // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
         }
-        // TODO: Trap on overflow? Probably going to be annoying.
-        // TODO: Look into SPV_KHR_no_integer_wrap_decoration which provides NoSignedWrap/NoUnsignedWrap.
 
-        return result_id;
+        return try wip.finalize();
     }
 
     fn airAddSubOverflow(