Commit b67d983abd

Robin Voetter <robin@voetter.nl>
2024-01-19 23:56:02
spirv: vectorize add/sub overflow
1 parent 761594e
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2582,103 +2582,108 @@ const DeclGen = struct {
         const lhs = try self.resolve(extra.lhs);
         const rhs = try self.resolve(extra.rhs);
 
-        const operand_ty = self.typeOf(extra.lhs);
         const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
 
         const info = try self.arithmeticTypeInfo(operand_ty);
         switch (info.class) {
             .composite_integer => return self.todo("overflow ops for composite integers", .{}),
-            .strange_integer => return self.todo("overflow ops for strange integers", .{}),
-            .integer => {},
+            .strange_integer, .integer => {},
             .float, .bool => unreachable,
         }
 
-        // The operand type must be the same as the result type in SPIR-V, which
-        // is the same as in Zig.
-        const operand_ty_ref = try self.resolveType(operand_ty, .direct);
-        const operand_ty_id = self.typeId(operand_ty_ref);
+        var wip_result = try self.elementWise(operand_ty);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty);
+        defer wip_ov.deinit();
+        for (wip_result.results, wip_ov.results, 0..) |*value_id, *ov_id, i| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
 
-        const bool_ty_ref = try self.resolveType(Type.bool, .direct);
+            // Normalize both so that we can properly check for overflow
+            const lhs_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, lhs_elem_id, info);
+            const rhs_norm_id = try self.normalizeInt(wip_result.scalar_ty_ref, rhs_elem_id, info);
+            const op_result_id = self.spv.allocId();
 
-        const ov_ty = result_ty.structFieldType(1, self.module);
-        // Note: result is stored in a struct, so indirect representation.
-        const ov_ty_ref = try self.resolveType(ov_ty, .indirect);
-
-        // TODO: Operations other than addition.
-        const value_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, add, .{
-            .id_result_type = operand_ty_id,
-            .id_result = value_id,
-            .operand_1 = lhs,
-            .operand_2 = rhs,
-        });
+            try self.func.body.emit(self.spv.gpa, add, .{
+                .id_result_type = wip_result.scalar_ty_id,
+                .id_result = op_result_id,
+                .operand_1 = lhs_norm_id,
+                .operand_2 = rhs_norm_id,
+            });
 
-        const overflowed_id = switch (info.signedness) {
-            .unsigned => blk: {
-                // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
-                // For subtraction the conditions need to be swapped.
-                const overflowed_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, ucmp, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = overflowed_id,
-                    .operand_1 = value_id,
-                    .operand_2 = lhs,
-                });
-                break :blk overflowed_id;
-            },
-            .signed => blk: {
-                // lhs - rhs
-                // For addition, overflow happened if:
-                // - rhs is negative and value > lhs
-                // - rhs is positive and value < lhs
-                // This can be shortened to:
-                //   (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
-                // = (rhs < 0) == (value > lhs)
-                // = (rhs < 0) == (lhs < value)
-                // Note that signed overflow is also wrapping in spir-v.
-                // For subtraction, overflow happened if:
-                // - rhs is negative and value < lhs
-                // - rhs is positive and value > lhs
-                // This can be shortened to:
-                //   (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
-                // = (rhs < 0) == (value < lhs)
-                // = (rhs < 0) == (lhs > value)
-
-                const rhs_lt_zero_id = self.spv.allocId();
-                const zero_id = try self.constInt(operand_ty_ref, 0);
-                try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = rhs_lt_zero_id,
-                    .operand_1 = rhs,
-                    .operand_2 = zero_id,
-                });
+            // Normalize the result so that the comparisons go well
+            value_id.* = try self.normalizeInt(wip_result.scalar_ty_ref, op_result_id, info);
+
+            const overflowed_id = switch (info.signedness) {
+                .unsigned => blk: {
+                    // Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
+                    // For subtraction the conditions need to be swapped.
+                    const overflowed_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, ucmp, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = overflowed_id,
+                        .operand_1 = value_id.*,
+                        .operand_2 = lhs_norm_id,
+                    });
+                    break :blk overflowed_id;
+                },
+                .signed => blk: {
+                    // lhs - rhs
+                    // For addition, overflow happened if:
+                    // - rhs is negative and value > lhs
+                    // - rhs is positive and value < lhs
+                    // This can be shortened to:
+                    //   (rhs < 0 and value > lhs) or (rhs >= 0 and value <= lhs)
+                    // = (rhs < 0) == (value > lhs)
+                    // = (rhs < 0) == (lhs < value)
+                    // Note that signed overflow is also wrapping in spir-v.
+                    // For subtraction, overflow happened if:
+                    // - rhs is negative and value < lhs
+                    // - rhs is positive and value > lhs
+                    // This can be shortened to:
+                    //   (rhs < 0 and value < lhs) or (rhs >= 0 and value >= lhs)
+                    // = (rhs < 0) == (value < lhs)
+                    // = (rhs < 0) == (lhs > value)
+
+                    const rhs_lt_zero_id = self.spv.allocId();
+                    const zero_id = try self.constInt(wip_result.scalar_ty_ref, 0);
+                    try self.func.body.emit(self.spv.gpa, .OpSLessThan, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = rhs_lt_zero_id,
+                        .operand_1 = rhs_norm_id,
+                        .operand_2 = zero_id,
+                    });
 
-                const value_gt_lhs_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, scmp, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = value_gt_lhs_id,
-                    .operand_1 = lhs,
-                    .operand_2 = value_id,
-                });
+                    const value_gt_lhs_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, scmp, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = value_gt_lhs_id,
+                        .operand_1 = lhs_norm_id,
+                        .operand_2 = value_id.*,
+                    });
 
-                const overflowed_id = self.spv.allocId();
-                try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
-                    .id_result_type = self.typeId(bool_ty_ref),
-                    .id_result = overflowed_id,
-                    .operand_1 = rhs_lt_zero_id,
-                    .operand_2 = value_gt_lhs_id,
-                });
-                break :blk overflowed_id;
-            },
-        };
+                    const overflowed_id = self.spv.allocId();
+                    try self.func.body.emit(self.spv.gpa, .OpLogicalEqual, .{
+                        .id_result_type = self.typeId(bool_ty_ref),
+                        .id_result = overflowed_id,
+                        .operand_1 = rhs_lt_zero_id,
+                        .operand_2 = value_gt_lhs_id,
+                    });
+                    break :blk overflowed_id;
+                },
+            };
+
+            ov_id.* = try self.intFromBool(wip_ov.scalar_ty_ref, overflowed_id);
+        }
 
-        // Construct the struct that Zig wants as result.
-        // The value should already be the correct type.
-        const ov_id = try self.intFromBool(ov_ty_ref, overflowed_id);
         return try self.constructStruct(
             result_ty,
             &.{ operand_ty, ov_ty },
-            &.{ value_id, ov_id },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
         );
     }
 
test/behavior/vector.zig
@@ -259,7 +259,6 @@ test "tuple to vector" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .aarch64) {
         // Regressed with LLVM 14:
@@ -1063,7 +1062,7 @@ test "@addWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    // if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -1111,7 +1110,6 @@ test "@subWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {