Commit 947dd36341

Marc Tiehuis <marc@tiehu.is>
2023-04-13 08:20:17
optimize udivmod
See https://reviews.llvm.org/D81809 for upstream description. In summary this is ~10x improvement for small divisors and similar performance for equal divisors. Closes #13523.
1 parent 0f5aff3
Changed files (1)
lib
compiler_rt
lib/compiler_rt/udivmod.zig
@@ -1,201 +1,157 @@
+const std = @import("std");
 const builtin = @import("builtin");
 const is_test = builtin.is_test;
-const native_endian = builtin.cpu.arch.endian();
-const std = @import("std");
+const Log2Int = std.math.Log2Int;
 
-const low = switch (native_endian) {
+const lo = switch (builtin.cpu.arch.endian()) {
     .Big => 1,
     .Little => 0,
 };
-const high = 1 - low;
+const hi = 1 - lo;
 
-pub fn udivmod(comptime DoubleInt: type, a: DoubleInt, b: DoubleInt, maybe_rem: ?*DoubleInt) DoubleInt {
+fn HalfInt(comptime T: type) type {
+    std.debug.assert(@typeInfo(T) == .Int);
+    std.debug.assert(@bitSizeOf(T) % 2 == 0);
+    return std.meta.Int(.unsigned, @bitSizeOf(T) / 2);
+}
+
+// Performs division of a double-word specified in its single-word components. Most commonly used
+// for computing u128 bit divisions in terms of 64-bit integers.
+//
+// q = U / v
+// r = U % v
+// where  U = (u1 | u0)
+fn divwide_generic(comptime T: type, _u1: T, _u0: T, v_: T, r: *T) T {
     @setRuntimeSafety(is_test);
+    var v = v_;
 
-    const double_int_bits = @typeInfo(DoubleInt).Int.bits;
-    const single_int_bits = @divExact(double_int_bits, 2);
-    const SingleInt = std.meta.Int(.unsigned, single_int_bits);
-    const SignedDoubleInt = std.meta.Int(.signed, double_int_bits);
-    const Log2SingleInt = std.math.Log2Int(SingleInt);
-
-    const n = @bitCast([2]SingleInt, a);
-    const d = @bitCast([2]SingleInt, b);
-    var q: [2]SingleInt = undefined;
-    var r: [2]SingleInt = undefined;
-    var sr: c_uint = undefined;
-    // special cases, X is unknown, K != 0
-    if (n[high] == 0) {
-        if (d[high] == 0) {
-            // 0 X
-            // ---
-            // 0 X
-            if (maybe_rem) |rem| {
-                rem.* = n[low] % d[low];
-            }
-            return n[low] / d[low];
-        }
-        // 0 X
-        // ---
-        // K X
+    const b = @as(T, 1) << (@bitSizeOf(T) / 2);
+    var un64: T = undefined;
+    var un10: T = undefined;
+
+    const s = @intCast(Log2Int(T), @clz(v));
+    if (s > 0) {
+        // Normalize divisor
+        v <<= s;
+        un64 = (_u1 << s) | (_u0 >> @intCast(Log2Int(T), (@bitSizeOf(T) - @intCast(T, s))));
+        un10 = _u0 << s;
+    } else {
+        // Avoid undefined behavior of (u0 >> @bitSizeOf(T))
+        un64 = _u1;
+        un10 = _u0;
+    }
+
+    // Break divisor up into two 32-bit digits
+    const vn1 = v >> (@bitSizeOf(T) / 2);
+    const vn0 = v & std.math.maxInt(HalfInt(T));
+
+    // Break right half of dividend into two digits
+    const un1 = un10 >> (@bitSizeOf(T) / 2);
+    const un0 = un10 & std.math.maxInt(HalfInt(T));
+
+    // Compute the first quotient digit, q1
+    var q1 = un64 / vn1;
+    var rhat = un64 -% q1 *% vn1;
+
+    // q1 has at most error 2. No more than 2 iterations
+    while (q1 >= b or q1 * vn0 > b * rhat + un1) {
+        q1 -= 1;
+        rhat += vn1;
+        if (rhat >= b) break;
+    }
+
+    var un21 = un64 *% b +% un1 -% q1 *% v;
+
+    // Compute the second quotient digit
+    var q0 = un21 / vn1;
+    rhat = un21 -% q0 *% vn1;
+
+    // q0 has at most error 2. No more than 2 iterations.
+    while (q0 >= b or q0 * vn0 > b * rhat + un0) {
+        q0 -= 1;
+        rhat += vn1;
+        if (rhat >= b) break;
+    }
+
+    r.* = (un21 *% b +% un0 -% q0 *% v) >> s;
+    return q1 *% b +% q0;
+}
+
+fn divwide(comptime T: type, _u1: T, _u0: T, v: T, r: *T) T {
+    @setRuntimeSafety(is_test);
+    if (T == u64 and builtin.target.cpu.arch == .x86_64) {
+        var rem: T = undefined;
+        const quo = asm (
+            \\divq %[v]
+            : [_] "={rax}" (-> T),
+              [_] "={rdx}" (rem),
+            : [v] "r" (v),
+              [_] "{rax}" (_u0),
+              [_] "{rdx}" (_u1),
+        );
+        r.* = rem;
+        return quo;
+    } else {
+        return divwide_generic(T, _u1, _u0, v, r);
+    }
+}
+
+// return q = a / b, *r = a % b
+pub fn udivmod(comptime T: type, a_: T, b_: T, maybe_rem: ?*T) T {
+    @setRuntimeSafety(is_test);
+    const HalfT = HalfInt(T);
+    const SignedT = std.meta.Int(.signed, @bitSizeOf(T));
+
+    if (b_ > a_) {
         if (maybe_rem) |rem| {
-            rem.* = n[low];
+            rem.* = a_;
         }
         return 0;
     }
-    // n[high] != 0
-    if (d[low] == 0) {
-        if (d[high] == 0) {
-            // K X
-            // ---
-            // 0 0
-            if (maybe_rem) |rem| {
-                rem.* = n[high] % d[low];
-            }
-            return n[high] / d[low];
-        }
-        // d[high] != 0
-        if (n[low] == 0) {
-            // K 0
-            // ---
-            // K 0
-            if (maybe_rem) |rem| {
-                r[high] = n[high] % d[high];
-                r[low] = 0;
-                rem.* = @bitCast(DoubleInt, r);
-            }
-            return n[high] / d[high];
-        }
-        // K K
-        // ---
-        // K 0
-        if ((d[high] & (d[high] - 1)) == 0) {
-            // d is a power of 2
-            if (maybe_rem) |rem| {
-                r[low] = n[low];
-                r[high] = n[high] & (d[high] - 1);
-                rem.* = @bitCast(DoubleInt, r);
-            }
-            return n[high] >> @intCast(Log2SingleInt, @ctz(d[high]));
-        }
-        // K K
-        // ---
-        // K 0
-        sr = @bitCast(c_uint, @as(c_int, @clz(d[high])) - @as(c_int, @clz(n[high])));
-        // 0 <= sr <= single_int_bits - 2 or sr large
-        if (sr > single_int_bits - 2) {
-            if (maybe_rem) |rem| {
-                rem.* = a;
-            }
-            return 0;
-        }
-        sr += 1;
-        // 1 <= sr <= single_int_bits - 1
-        // q.all = a << (double_int_bits - sr);
-        q[low] = 0;
-        q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
-        // r.all = a >> sr;
-        r[high] = n[high] >> @intCast(Log2SingleInt, sr);
-        r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
-    } else {
-        // d[low] != 0
-        if (d[high] == 0) {
-            // K X
-            // ---
-            // 0 K
-            if ((d[low] & (d[low] - 1)) == 0) {
-                // d is a power of 2
-                if (maybe_rem) |rem| {
-                    rem.* = n[low] & (d[low] - 1);
-                }
-                if (d[low] == 1) {
-                    return a;
-                }
-                sr = @ctz(d[low]);
-                q[high] = n[high] >> @intCast(Log2SingleInt, sr);
-                q[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
-                return @bitCast(DoubleInt, q);
-            }
-            // K X
-            // ---
-            // 0 K
-            sr = 1 + single_int_bits + @as(c_uint, @clz(d[low])) - @as(c_uint, @clz(n[high]));
-            // 2 <= sr <= double_int_bits - 1
-            // q.all = a << (double_int_bits - sr);
-            // r.all = a >> sr;
-            if (sr == single_int_bits) {
-                q[low] = 0;
-                q[high] = n[low];
-                r[high] = 0;
-                r[low] = n[high];
-            } else if (sr < single_int_bits) {
-                // 2 <= sr <= single_int_bits - 1
-                q[low] = 0;
-                q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
-                r[high] = n[high] >> @intCast(Log2SingleInt, sr);
-                r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
-            } else {
-                // single_int_bits + 1 <= sr <= double_int_bits - 1
-                q[low] = n[low] << @intCast(Log2SingleInt, double_int_bits - sr);
-                q[high] = (n[high] << @intCast(Log2SingleInt, double_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr - single_int_bits));
-                r[high] = 0;
-                r[low] = n[high] >> @intCast(Log2SingleInt, sr - single_int_bits);
-            }
+
+    var a = @bitCast([2]HalfT, a_);
+    var b = @bitCast([2]HalfT, b_);
+    var q: [2]HalfT = undefined;
+    var r: [2]HalfT = undefined;
+
+    // When the divisor fits in 64 bits, we can use an optimized path
+    if (b[hi] == 0) {
+        r[hi] = 0;
+        if (a[hi] < b[lo]) {
+            // The result fits in 64 bits
+            q[hi] = 0;
+            q[lo] = divwide(HalfT, a[hi], a[lo], b[lo], &r[lo]);
         } else {
-            // K X
-            // ---
-            // K K
-            sr = @bitCast(c_uint, @as(c_int, @clz(d[high])) - @as(c_int, @clz(n[high])));
-            // 0 <= sr <= single_int_bits - 1 or sr large
-            if (sr > single_int_bits - 1) {
-                if (maybe_rem) |rem| {
-                    rem.* = a;
-                }
-                return 0;
-            }
-            sr += 1;
-            // 1 <= sr <= single_int_bits
-            // q.all = a << (double_int_bits - sr);
-            // r.all = a >> sr;
-            q[low] = 0;
-            if (sr == single_int_bits) {
-                q[high] = n[low];
-                r[high] = 0;
-                r[low] = n[high];
-            } else {
-                r[high] = n[high] >> @intCast(Log2SingleInt, sr);
-                r[low] = (n[high] << @intCast(Log2SingleInt, single_int_bits - sr)) | (n[low] >> @intCast(Log2SingleInt, sr));
-                q[high] = n[low] << @intCast(Log2SingleInt, single_int_bits - sr);
-            }
+            // First, divide with the high part to get the remainder. After that a_hi < b_lo.
+            q[hi] = a[hi] / b[lo];
+            q[lo] = divwide(HalfT, a[hi] % b[lo], a[lo], b[lo], &r[lo]);
         }
+        if (maybe_rem) |rem| {
+            rem.* = @bitCast(T, r);
+        }
+        return @bitCast(T, q);
     }
-    // Not a special case
-    // q and r are initialized with:
-    // q.all = a << (double_int_bits - sr);
-    // r.all = a >> sr;
-    // 1 <= sr <= double_int_bits - 1
-    var carry: u32 = 0;
-    var r_all: DoubleInt = undefined;
-    while (sr > 0) : (sr -= 1) {
-        // r:q = ((r:q)  << 1) | carry
-        r[high] = (r[high] << 1) | (r[low] >> (single_int_bits - 1));
-        r[low] = (r[low] << 1) | (q[high] >> (single_int_bits - 1));
-        q[high] = (q[high] << 1) | (q[low] >> (single_int_bits - 1));
-        q[low] = (q[low] << 1) | carry;
-        // carry = 0;
-        // if (r.all >= b)
-        // {
-        //     r.all -= b;
-        //      carry = 1;
+
+    // 0 <= shift <= 63
+    var shift: Log2Int(T) = @clz(b[hi]) - @clz(a[hi]);
+    var af = @bitCast(T, a);
+    var bf = @bitCast(T, b) << shift;
+    q = @bitCast([2]HalfT, @as(T, 0));
+
+    for (0..shift + 1) |_| {
+        q[lo] <<= 1;
+        // Branchless version of:
+        // if (a >= b) {
+        //     a -= b;
+        //     q[lo] |= 1;
         // }
-        r_all = @bitCast(DoubleInt, r);
-        const s: SignedDoubleInt = @bitCast(SignedDoubleInt, b -% r_all -% 1) >> (double_int_bits - 1);
-        carry = @intCast(u32, s & 1);
-        r_all -= b & @bitCast(DoubleInt, s);
-        r = @bitCast([2]SingleInt, r_all);
+        const s = @bitCast(SignedT, bf -% af -% 1) >> (@bitSizeOf(T) - 1);
+        q[lo] |= @intCast(HalfT, s & 1);
+        af -= bf & @bitCast(T, s);
+        bf >>= 1;
     }
-    const q_all = (@bitCast(DoubleInt, q) << 1) | carry;
     if (maybe_rem) |rem| {
-        rem.* = r_all;
+        rem.* = @bitCast(T, af);
     }
-    return q_all;
+    return @bitCast(T, q);
 }