Commit 7d511d6428

Niles Salter <Validark@pm.me>
2023-06-22 19:32:28
[heapsort] Protect against integer overflow
First, I renamed `n` to `b`, as that is less confusing: it is not a length, it is a right boundary. The invariant maintained is `cur < b`, so in the worst case `2*cur + 1` can be as large as `2b - 1`. Since that value is not guaranteed to fit within `maxInt`, we have to add one overflow check to `siftDown` to make sure we avoid undefined behavior. LLVM also seems to have a nicer time compiling this version of the function. It is about 2x faster in my tests (I think LLVM was stumped by the `child += @intFromBool` line), and adding or removing the overflow check makes a negligible performance difference on my machine. Of course, we could check `2b <= maxInt` in the parent function and dispatch to a version of the function without the overflow check in the common case, but that is probably not worth the code size just to eliminate a single instruction.
1 parent c608967
Changed files (1)
lib
lib/std/sort.zig
@@ -36,6 +36,8 @@ pub fn insertion(
 /// O(1) memory (no allocator required).
 /// Sorts in ascending order with respect to the given `lessThan` function.
 pub fn insertionContext(a: usize, b: usize, context: anytype) void {
+    assert(a <= b);
+
     var i = a + 1;
     while (i < b) : (i += 1) {
         var j = i;
@@ -73,6 +75,7 @@ pub fn heap(
 /// O(1) memory (no allocator required).
 /// Sorts in ascending order with respect to the given `lessThan` function.
 pub fn heapContext(a: usize, b: usize, context: anytype) void {
+    assert(a <= b);
     // build the heap in linear time.
     var i = a + (b - a) / 2;
     while (i > a) {
@@ -89,22 +92,33 @@ pub fn heapContext(a: usize, b: usize, context: anytype) void {
     }
 }
 
-fn siftDown(a: usize, root: usize, n: usize, context: anytype) void {
-    var node = root;
+fn siftDown(a: usize, target: usize, b: usize, context: anytype) void {
+    var cur = target;
     while (true) {
-        var child = a + 2 * (node - a) + 1;
-        if (child >= n) break;
+        // When we don't overflow from the multiply below, the following expression equals (2*cur) - (2*a) + a + 1
+        // The `+ a + 1` is safe because:
+        //  for `a > 0` then `2a >= a + 1`.
+        //  for `a = 0`, the expression equals `2*cur+1`. `2*cur` is an even number, therefore adding 1 is safe.
+        var child = (math.mul(usize, cur - a, 2) catch break) + a + 1;
+
+        // stop if we overshot the boundary
+        if (!(child < b)) break;
 
-        // choose the greater child.
-        child += @intFromBool(child + 1 < n and context.lessThan(child, child + 1));
+        // `next_child` is at most `b`, therefore no overflow is possible
+        const next_child = child + 1;
+
+        // store the greater child in `child`
+        if (next_child < b and context.lessThan(child, next_child)) {
+            child = next_child;
+        }
 
-        // stop if the invariant holds at `node`.
-        if (!context.lessThan(node, child)) break;
+        // stop if the Heap invariant holds at `cur`.
+        if (context.lessThan(child, cur)) break;
 
-        // swap `node` with the greater child,
+        // swap `cur` with the greater child,
         // move one step down, and continue sifting.
-        context.swap(node, child);
-        node = child;
+        context.swap(child, cur);
+        cur = child;
     }
 }