Commit 2e6a6d7564

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-10 07:56:48
llvm/cbe: fix signed `@mod`/`@divFloor` computations
Closes #15636
1 parent 31a13ce
Changed files (3)
lib
src
codegen
test
behavior
lib/zig.h
@@ -487,14 +487,14 @@ typedef ptrdiff_t intptr_t;
     zig_basic_operator(uint##w##_t, div_floor_u##w, /) \
 \
     static inline int##w##_t zig_div_floor_i##w(int##w##_t lhs, int##w##_t rhs) { \
-        return lhs / rhs - (((lhs ^ rhs) & (lhs % rhs)) < INT##w##_C(0)); \
+        return lhs / rhs + (lhs % rhs != INT##w##_C(0) ? zig_shr_i##w(lhs ^ rhs, UINT8_C(w) - UINT8_C(1)) : INT##w##_C(0)); /* floored division: the quotient is corrected only when the remainder is nonzero; the correction is (lhs ^ rhs) shifted right by w-1 bits. NOTE(review): presumably zig_shr_i##w is an arithmetic shift, giving -1 when the operand signs differ and 0 otherwise - confirm. */ \
     } \
 \
     zig_basic_operator(uint##w##_t, mod_u##w, %) \
 \
     static inline int##w##_t zig_mod_i##w(int##w##_t lhs, int##w##_t rhs) { \
         int##w##_t rem = lhs % rhs; \
-        return rem + (((lhs ^ rhs) & rem) < INT##w##_C(0) ? rhs : INT##w##_C(0)); \
+        return rem + (rem != INT##w##_C(0) ? rhs & zig_shr_i##w(lhs ^ rhs, UINT8_C(w) - UINT8_C(1)) : INT##w##_C(0)); /* a zero remainder is returned unchanged; otherwise rhs, masked by (lhs ^ rhs) shifted right by w-1 bits, is added. NOTE(review): presumably the shift is arithmetic, so the mask keeps all of rhs when the operand signs differ and clears it otherwise - confirm. */ \
     } \
 \
     static inline uint##w##_t zig_shlw_u##w(uint##w##_t lhs, uint8_t rhs, uint8_t bits) { \
@@ -1078,7 +1078,7 @@ static inline int64_t zig_bit_reverse_i64(int64_t val, uint8_t bits) {
         uint##w##_t temp = val - ((val >> 1) & (UINT##w##_MAX / 3)); \
         temp = (temp & (UINT##w##_MAX / 5)) + ((temp >> 2) & (UINT##w##_MAX / 5)); \
         temp = (temp + (temp >> 4)) & (UINT##w##_MAX / 17); \
-        return temp * (UINT##w##_MAX / 255) >> (w - 8); \
+        return temp * (UINT##w##_MAX / 255) >> (UINT8_C(w) - UINT8_C(8)); \
     } \
 \
     zig_builtin_popcount_common(w)
@@ -1298,15 +1298,6 @@ static inline zig_i128 zig_rem_i128(zig_i128 lhs, zig_i128 rhs) {
     return lhs % rhs;
 }
 
-static inline zig_i128 zig_div_floor_i128(zig_i128 lhs, zig_i128 rhs) {
-    return zig_div_trunc_i128(lhs, rhs) - (((lhs ^ rhs) & zig_rem_i128(lhs, rhs)) < zig_make_i128(0, 0));
-}
-
-static inline zig_i128 zig_mod_i128(zig_i128 lhs, zig_i128 rhs) {
-    zig_i128 rem = zig_rem_i128(lhs, rhs);
-    return rem + (((lhs ^ rhs) & rem) < zig_make_i128(0, 0) ? rhs : zig_make_i128(0, 0));
-}
-
 #else /* zig_has_int128 */
 
 static inline zig_u128 zig_not_u128(zig_u128 val, uint8_t bits) {
@@ -1394,20 +1385,26 @@ static zig_i128 zig_rem_i128(zig_i128 lhs, zig_i128 rhs) {
     return __modti3(lhs, rhs);
 }
 
-static inline zig_i128 zig_mod_i128(zig_i128 lhs, zig_i128 rhs) {
-    zig_i128 rem = zig_rem_i128(lhs, rhs);
-    return zig_add_i128(rem, ((lhs.hi ^ rhs.hi) & rem.hi) < INT64_C(0) ? rhs : zig_make_i128(0, 0));
-}
+#endif /* zig_has_int128 */
+
+#define zig_div_floor_u128 zig_div_trunc_u128
 
 static inline zig_i128 zig_div_floor_i128(zig_i128 lhs, zig_i128 rhs) {
-    return zig_sub_i128(zig_div_trunc_i128(lhs, rhs), zig_make_i128(0, zig_cmp_i128(zig_and_i128(zig_xor_i128(lhs, rhs), zig_rem_i128(lhs, rhs)), zig_make_i128(0, 0)) < INT32_C(0)));
+    zig_i128 rem = zig_rem_i128(lhs, rhs); /* floored 128-bit division: start from the remainder that pairs with zig_div_trunc_i128 */
+    int64_t mask = zig_or_u64((uint64_t)zig_hi_i128(rem), zig_lo_i128(rem)) != UINT64_C(0) /* OR of both halves is nonzero iff rem != 0 */
+        ? zig_shr_i64(zig_xor_i64(zig_hi_i128(lhs), zig_hi_i128(rhs)), UINT8_C(63)) : INT64_C(0); /* XOR of the high halves shifted right by 63; NOTE(review): presumably an arithmetic shift, so -1 when operand signs differ, else 0 - confirm */
+    return zig_add_i128(zig_div_trunc_i128(lhs, rhs), zig_make_i128(mask, (uint64_t)mask)); /* mask fills both halves of the addend, so the truncated quotient gets 0 or (presumably) -1 added */
 }
 
-#endif /* zig_has_int128 */
-
-#define zig_div_floor_u128 zig_div_trunc_u128
 #define zig_mod_u128 zig_rem_u128
 
+static inline zig_i128 zig_mod_i128(zig_i128 lhs, zig_i128 rhs) { /* 128-bit signed modulo: the remainder from zig_rem_i128, plus rhs when the mask below selects it */
+    zig_i128 rem = zig_rem_i128(lhs, rhs);
+    int64_t mask = zig_or_u64((uint64_t)zig_hi_i128(rem), zig_lo_i128(rem)) != UINT64_C(0) /* OR of both halves is nonzero iff rem != 0 */
+        ? zig_shr_i64(zig_xor_i64(zig_hi_i128(lhs), zig_hi_i128(rhs)), UINT8_C(63)) : INT64_C(0); /* XOR of the high halves shifted right by 63; NOTE(review): presumably an arithmetic shift, so all-ones when operand signs differ, else 0 - confirm */
+    return zig_add_i128(rem, zig_and_i128(rhs, zig_make_i128(mask, (uint64_t)mask))); /* mask fills both halves, so the AND yields either rhs or zero before the add */
+}
+
 static inline zig_u128 zig_min_u128(zig_u128 lhs, zig_u128 rhs) {
     return zig_cmp_u128(lhs, rhs) < INT32_C(0) ? lhs : rhs;
 }
src/codegen/llvm.zig
@@ -7215,20 +7215,28 @@ pub const FuncGen = struct {
             return self.buildFloatOp(.floor, inst_ty, 1, .{result});
         }
         if (scalar_ty.isSignedInt()) {
-            // const d = @divTrunc(a, b);
-            // const r = @rem(a, b);
-            // return if (r == 0) d else d - ((a < 0) ^ (b < 0));
-            const result_llvm_ty = try self.dg.lowerType(inst_ty);
-            const zero = result_llvm_ty.constNull();
-            const div_trunc = self.builder.buildSDiv(lhs, rhs, "");
+            const target = self.dg.module.getTarget();
+            const inst_llvm_ty = try self.dg.lowerType(inst_ty);
+            const scalar_bit_size_minus_one = scalar_ty.bitSize(target) - 1;
+            const bit_size_minus_one = if (inst_ty.zigTypeTag() == .Vector) const_vector: {
+                const vec_len = inst_ty.vectorLen();
+                const scalar_llvm_ty = try self.dg.lowerType(scalar_ty);
+
+                const shifts = try self.gpa.alloc(*llvm.Value, vec_len);
+                defer self.gpa.free(shifts);
+
+                @memset(shifts, scalar_llvm_ty.constInt(scalar_bit_size_minus_one, .False));
+                break :const_vector llvm.constVector(shifts.ptr, vec_len);
+            } else inst_llvm_ty.constInt(scalar_bit_size_minus_one, .False);
+
+            const div = self.builder.buildSDiv(lhs, rhs, "");
             const rem = self.builder.buildSRem(lhs, rhs, "");
-            const rem_eq_0 = self.builder.buildICmp(.EQ, rem, zero, "");
-            const a_lt_0 = self.builder.buildICmp(.SLT, lhs, zero, "");
-            const b_lt_0 = self.builder.buildICmp(.SLT, rhs, zero, "");
-            const a_b_xor = self.builder.buildXor(a_lt_0, b_lt_0, "");
-            const a_b_xor_ext = self.builder.buildZExt(a_b_xor, div_trunc.typeOf(), "");
-            const d_sub_xor = self.builder.buildSub(div_trunc, a_b_xor_ext, "");
-            return self.builder.buildSelect(rem_eq_0, div_trunc, d_sub_xor, "");
+            const div_sign = self.builder.buildXor(lhs, rhs, "");
+            const div_sign_mask = self.builder.buildAShr(div_sign, bit_size_minus_one, "");
+            const zero = inst_llvm_ty.constNull();
+            const rem_nonzero = self.builder.buildICmp(.NE, rem, zero, "");
+            const correction = self.builder.buildSelect(rem_nonzero, div_sign_mask, zero, "");
+            return self.builder.buildNSWAdd(div, correction, "");
         }
         return self.builder.buildUDiv(lhs, rhs, "");
     }
@@ -7280,12 +7288,27 @@ pub const FuncGen = struct {
             return self.builder.buildSelect(ltz, c, a, "");
         }
         if (scalar_ty.isSignedInt()) {
-            const a = self.builder.buildSRem(lhs, rhs, "");
-            const b = self.builder.buildNSWAdd(a, rhs, "");
-            const c = self.builder.buildSRem(b, rhs, "");
+            const target = self.dg.module.getTarget();
+            const scalar_bit_size_minus_one = scalar_ty.bitSize(target) - 1;
+            const bit_size_minus_one = if (inst_ty.zigTypeTag() == .Vector) const_vector: {
+                const vec_len = inst_ty.vectorLen();
+                const scalar_llvm_ty = try self.dg.lowerType(scalar_ty);
+
+                const shifts = try self.gpa.alloc(*llvm.Value, vec_len);
+                defer self.gpa.free(shifts);
+
+                @memset(shifts, scalar_llvm_ty.constInt(scalar_bit_size_minus_one, .False));
+                break :const_vector llvm.constVector(shifts.ptr, vec_len);
+            } else inst_llvm_ty.constInt(scalar_bit_size_minus_one, .False);
+
+            const rem = self.builder.buildSRem(lhs, rhs, "");
+            const div_sign = self.builder.buildXor(lhs, rhs, "");
+            const div_sign_mask = self.builder.buildAShr(div_sign, bit_size_minus_one, "");
+            const rhs_masked = self.builder.buildAnd(rhs, div_sign_mask, "");
             const zero = inst_llvm_ty.constNull();
-            const ltz = self.builder.buildICmp(.SLT, lhs, zero, "");
-            return self.builder.buildSelect(ltz, c, a, "");
+            const rem_nonzero = self.builder.buildICmp(.NE, rem, zero, "");
+            const correction = self.builder.buildSelect(rem_nonzero, rhs_masked, zero, "");
+            return self.builder.buildNSWAdd(rem, correction, "");
         }
         return self.builder.buildURem(lhs, rhs, "");
     }
test/behavior/math.zig
@@ -449,6 +449,9 @@ fn testDivision() !void {
     try expect(mod(i32, 10, 12) == 10);
     try expect(mod(i32, -14, 12) == 10);
     try expect(mod(i32, -2, 12) == 10);
+    try expect(mod(i32, 10, -12) == -2);
+    try expect(mod(i32, -14, -12) == -2);
+    try expect(mod(i32, -2, -12) == -2);
 
     comptime {
         try expect(