Commit b1499df1b8

Robin Voetter <robin@voetter.nl>
2023-10-08 18:52:58
spirv: sign-extension for strange integers
1 parent dc44baf
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -270,6 +270,12 @@ const DeclGen = struct {
         /// This is the actual number of bits of the type, not the size of the backing integer.
         bits: u16,
 
+        /// The number of bits required to store the type.
+        /// For `integer` and `float`, this is equal to `bits`.
+        /// For `strange_integer` and `bool` this is the size of the backing integer.
+        /// For `composite_integer` this is 0 (TODO)
+        backing_bits: u16,
+
         /// Whether the type is a vector.
         is_vector: bool,
 
@@ -499,12 +505,14 @@ const DeclGen = struct {
         return switch (ty.zigTypeTag(mod)) {
             .Bool => ArithmeticTypeInfo{
                 .bits = 1, // Doesn't matter for this class.
+                .backing_bits = self.backingIntBits(1).?,
                 .is_vector = false,
                 .signedness = .unsigned, // Technically, but doesn't matter for this class.
                 .class = .bool,
             },
             .Float => ArithmeticTypeInfo{
                 .bits = ty.floatBits(target),
+                .backing_bits = ty.floatBits(target), // TODO: F80?
                 .is_vector = false,
                 .signedness = .signed, // Technically, but doesn't matter for this class.
                 .class = .float,
@@ -515,6 +523,7 @@ const DeclGen = struct {
                 const maybe_backing_bits = self.backingIntBits(int_info.bits);
                 break :blk ArithmeticTypeInfo{
                     .bits = int_info.bits,
+                    .backing_bits = maybe_backing_bits orelse 0,
                     .is_vector = false,
                     .signedness = int_info.signedness,
                     .class = if (maybe_backing_bits) |backing_bits|
@@ -2154,17 +2163,48 @@ const DeclGen = struct {
         return result_id;
     }
 
-    fn maskStrangeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, bits: u16) !IdRef {
-        const mask_value = if (bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(bits))) - 1;
-        const result_id = self.spv.allocId();
-        const mask_id = try self.constInt(ty_ref, mask_value);
-        try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
-            .id_result_type = self.typeId(ty_ref),
-            .id_result = result_id,
-            .operand_1 = value_id,
-            .operand_2 = mask_id,
-        });
-        return result_id;
+    /// This function canonicalizes a "strange" integer value:
+    /// For unsigned integers, the value is masked so that only the relevant bits can contain
+    /// non-zeros.
+    /// For signed integers, the value is also sign extended.
+    fn normalizeInt(self: *DeclGen, ty_ref: CacheRef, value_id: IdRef, info: ArithmeticTypeInfo) !IdRef {
+        if (info.bits == info.backing_bits) {
+            return value_id;
+        }
+
+        switch (info.signedness) {
+            .unsigned => {
+                const mask_value = if (info.bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
+                const result_id = self.spv.allocId();
+                const mask_id = try self.constInt(ty_ref, mask_value);
+                try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = result_id,
+                    .operand_1 = value_id,
+                    .operand_2 = mask_id,
+                });
+                return result_id;
+            },
+            .signed => {
+                // Shift left and right so that we can copy the sight bit that way.
+                const shift_amt_id = try self.constInt(ty_ref, info.backing_bits - info.bits);
+                const left_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpShiftLeftLogical, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = left_id,
+                    .base = value_id,
+                    .shift = shift_amt_id,
+                });
+                const right_id = self.spv.allocId();
+                try self.func.body.emit(self.spv.gpa, .OpShiftRightArithmetic, .{
+                    .id_result_type = self.typeId(ty_ref),
+                    .id_result = right_id,
+                    .base = left_id,
+                    .shift = shift_amt_id,
+                });
+                return right_id;
+            },
+        }
     }
 
     fn airArithOp(
@@ -2199,8 +2239,8 @@ const DeclGen = struct {
             },
             .strange_integer => blk: {
                 if (!modular) {
-                    lhs_id = try self.maskStrangeInt(result_ty_ref, lhs_id, info.bits);
-                    rhs_id = try self.maskStrangeInt(result_ty_ref, rhs_id, info.bits);
+                    lhs_id = try self.normalizeInt(result_ty_ref, lhs_id, info);
+                    rhs_id = try self.normalizeInt(result_ty_ref, rhs_id, info);
                 }
                 break :blk switch (info.signedness) {
                     .signed => @as(usize, 1),
@@ -2565,8 +2605,8 @@ const DeclGen = struct {
                 .strange_integer => sign: {
                     const op_ty_ref = try self.resolveType(op_ty, .direct);
                     // Mask operands before performing comparison.
-                    cmp_lhs_id = try self.maskStrangeInt(op_ty_ref, cmp_lhs_id, info.bits);
-                    cmp_rhs_id = try self.maskStrangeInt(op_ty_ref, cmp_rhs_id, info.bits);
+                    cmp_lhs_id = try self.normalizeInt(op_ty_ref, cmp_lhs_id, info);
+                    cmp_rhs_id = try self.normalizeInt(op_ty_ref, cmp_rhs_id, info);
                     break :sign info.signedness;
                 },
                 .integer => info.signedness,
test/behavior/math.zig
@@ -468,6 +468,9 @@ fn testDivision() !void {
     try expect(mod(i32, -14, -12) == -2);
     try expect(mod(i32, -2, -12) == -2);
 
+    try expect(divTrunc(i20, 20, -5) == -4);
+    try expect(divTrunc(i20, -20, -4) == 5);
+
     comptime {
         try expect(
             1194735857077236777412821811143690633098347576 % 508740759824825164163191790951174292733114988 == 177254337427586449086438229241342047632117600,