Commit 2f815853dc

Robin Voetter <robin@voetter.nl>
2024-01-16 23:06:15
spirv: shlWithOverflow
1 parent 15cf5f8
Changed files (3)
src
codegen
test
src/codegen/spirv.zig
@@ -1782,19 +1782,6 @@ const DeclGen = struct {
             wip.dg.gpa.free(wip.results);
         }
 
-        /// Return the scalar type of an input vector. This type is expected to be a vector
-        /// if `wip.is_vector`, and a scalar otherwise.
-        fn scalarType(wip: WipElementWise, ty: Type) Type {
-            const mod = wip.dg.module;
-            if (wip.is_vector) {
-                assert(ty.isVector(mod));
-                return ty.childType(mod);
-            } else {
-                assert(!ty.isVector(mod));
-                return ty;
-            }
-        }
-
         /// Utility function to extract the element at a particular index in an
         /// input vector. This type is expected to be a vector if `wip.is_vector`, and
         /// a scalar otherwise.
@@ -1844,7 +1831,7 @@ const DeclGen = struct {
         const results = try self.gpa.alloc(IdRef, num_results);
         for (results) |*result| result.* = undefined;
 
-        const scalar_ty = if (is_vector) result_ty.childType(mod) else result_ty;
+        const scalar_ty = result_ty.scalarType(mod);
         const scalar_ty_ref = try self.resolveType(scalar_ty, .direct);
 
         return .{
@@ -2198,6 +2185,7 @@ const DeclGen = struct {
 
             .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
             .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
+            .shl_with_overflow => try self.airShlOverflow(inst),
 
             .shuffle => try self.airShuffle(inst),
 
@@ -2343,23 +2331,30 @@ const DeclGen = struct {
         const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
+
         const result_ty = self.typeOfIndex(inst);
+        const shift_ty = self.typeOf(bin_op.rhs);
+        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
 
-        // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
-        // so just manually upcast it if required.
-        // TODO(robin)
+        const info = try self.arithmeticTypeInfo(result_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("shift ops for composite integers", .{}),
+            .integer, .strange_integer => {},
+            .float, .bool => unreachable,
+        }
 
         var wip = try self.elementWise(result_ty);
         defer wip.deinit();
-
-        const shift_ty = wip.scalarType(self.typeOf(bin_op.rhs));
-        const shift_ty_ref = try self.resolveType(shift_ty, .direct);
-
         for (0..wip.results.len) |i| {
             const lhs_elem_id = try wip.elementAt(result_ty, lhs_id, i);
-            const rhs_elem_id = try wip.elementAt(result_ty, rhs_id, i);
+            const rhs_elem_id = try wip.elementAt(shift_ty, rhs_id, i);
+
+            // TODO: Can we omit normalizing lhs?
+            const lhs_norm_id = try self.normalizeInt(wip.scalar_ty_ref, lhs_elem_id, info);
 
-            const shift_id = if (shift_ty_ref != wip.result_ty_ref) blk: {
+            // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
+            // so just manually upcast it if required.
+            const shift_id = if (scalar_shift_ty_ref != wip.scalar_ty_ref) blk: {
                 const shift_id = self.spv.allocId();
                 try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
                     .id_result_type = wip.scalar_ty_id,
@@ -2368,12 +2363,13 @@ const DeclGen = struct {
                 });
                 break :blk shift_id;
             } else rhs_elem_id;
+            const shift_norm_id = try self.normalizeInt(wip.scalar_ty_ref, shift_id, info);
 
             const args = .{
                 .id_result_type = wip.scalar_ty_id,
                 .id_result = wip.allocId(i),
-                .base = lhs_elem_id,
-                .shift = shift_id,
+                .base = lhs_norm_id,
+                .shift = shift_norm_id,
             };
 
             if (result_ty.isSignedInt(mod)) {
@@ -2680,6 +2676,88 @@ const DeclGen = struct {
         );
     }
 
+    fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const mod = self.module;
+        const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
+        const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+        const lhs = try self.resolve(extra.lhs);
+        const rhs = try self.resolve(extra.rhs);
+
+        const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const shift_ty = self.typeOf(extra.rhs);
+        const scalar_shift_ty_ref = try self.resolveType(shift_ty.scalarType(mod), .direct);
+
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+
+        const info = try self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("overflow shift for composite integers", .{}),
+            .integer, .strange_integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip_result = try self.elementWise(operand_ty);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty);
+        defer wip_ov.deinit();
+        for (0..wip_result.results.len, wip_ov.results) |i, *ov_id| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(shift_ty, rhs, i);
+
+            // Normalize both so that we can shift back and check if the result is the same.
+            const lhs_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, lhs_elem_id, info);
+
+            // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
+            // so just manually upcast it if required.
+            const shift_id = if (scalar_shift_ty_ref != wip_result.scalar_ty_ref) blk: {
+                const shift_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                    .id_result_type = wip_result.scalar_ty_id,
+                    .id_result = shift_id,
+                    .unsigned_value = rhs_elem_id,
+                });
+                break :blk shift_id;
+            } else rhs_elem_id;
+            const shift_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, shift_id, info);
+
+            try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+                .id_result_type = wip_result.scalar_ty_id,
+                .id_result = wip_result.allocId(i),
+                .base = lhs_norm_id,
+                .shift = shift_norm_id,
+            });
+
+            // To check if overflow happened, just check if the right-shifted result is the same value.
+            const right_shift_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpShiftRightLogical, .{
+                .id_result_type = wip_result.scalar_ty_id,
+                .id_result = right_shift_id,
+                .base = try self.normalizeInt(wip_result.scalar_ty_ref, wip_result.results[i], info),
+                .shift = shift_norm_id,
+            });
+
+            const overflowed_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpINotEqual, .{
+                .id_result_type = self.typeId(bool_ty_ref),
+                .id_result = overflowed_id,
+                .operand_1 = lhs_norm_id,
+                .operand_2 = right_shift_id,
+            });
+
+            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+        }
+
+        return try self.constructStruct(
+            result_ty,
+            &.{ operand_ty, ov_ty },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
+        );
+    }
+
     fn airShuffle(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const mod = self.module;
         if (self.liveness.isUnused(inst)) return null;
test/behavior/math.zig
@@ -1328,8 +1328,6 @@ fn testShlTrunc(x: u16) !void {
 }
 
 test "exact shift left" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try testShlExact(0b00110101);
     try comptime testShlExact(0b00110101);
 
test/behavior/vector.zig
@@ -179,7 +179,6 @@ test "array vector coercion - odd sizes" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
@@ -219,7 +218,6 @@ test "array to vector with element type coercion" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
@@ -659,7 +657,6 @@ test "vector shift operators" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTestShift(x: anytype, y: anytype) !void {
@@ -1168,7 +1165,6 @@ test "@shlWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -1453,7 +1449,6 @@ test "compare vectors with different element types" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO
 
     var a: @Vector(2, u8) = .{ 1, 2 };
     var b: @Vector(2, u9) = .{ 3, 0 };