Commit 700ea694b2

Niles Salter <Validark@pm.me>
2023-06-13 22:55:58
Fix pdqSort+heapSort for ranges besides 0..len (#15982)
1 parent 129afba
Changed files (2)
lib
lib/std/sort/pdq.zig
@@ -43,7 +43,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void {
     // slices of up to this length get sorted using insertion sort.
     const max_insertion = 24;
     // number of allowed imbalanced partitions before switching to heap sort.
-    const max_limit = std.math.floorPowerOfTwo(usize, b) + 1;
+    const max_limit = std.math.floorPowerOfTwo(usize, b - a) + 1;
 
     // set upper bound on stack memory usage.
     const Range = struct { a: usize, b: usize, limit: usize };
@@ -100,7 +100,7 @@ pub fn pdqContext(a: usize, b: usize, context: anytype) void {
             // if the chosen pivot is equal to the predecessor, then it's the smallest element in the
             // slice. Partition the slice into elements equal to and elements greater than the pivot.
             // This case is usually hit when the slice contains many duplicate elements.
-            if (range.a > 0 and !context.lessThan(range.a - 1, pivot)) {
+            if (range.a > a and !context.lessThan(range.a - 1, pivot)) {
                 range.a = partitionEqual(range.a, range.b, pivot, context);
                 continue;
             }
@@ -284,13 +284,13 @@ fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint {
     if (len >= 8) {
         if (len >= shortest_ninther) {
             // find medians in the neighborhoods of `i`, `j` and `k`
-            i = sort3(i - 1, i, i + 1, &swaps, context);
-            j = sort3(j - 1, j, j + 1, &swaps, context);
-            k = sort3(k - 1, k, k + 1, &swaps, context);
+            sort3(i - 1, i, i + 1, &swaps, context);
+            sort3(j - 1, j, j + 1, &swaps, context);
+            sort3(k - 1, k, k + 1, &swaps, context);
         }
 
-        // find the median among `i`, `j` and `k`
-        j = sort3(i, j, k, &swaps, context);
+        // find the median among `i`, `j` and `k` and stores it in `j`
+        sort3(i, j, k, &swaps, context);
     }
 
     pivot.* = j;
@@ -301,7 +301,7 @@ fn chosePivot(a: usize, b: usize, pivot: *usize, context: anytype) Hint {
     };
 }
 
-fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) usize {
+fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) void {
     if (context.lessThan(b, a)) {
         swaps.* += 1;
         context.swap(b, a);
@@ -316,8 +316,6 @@ fn sort3(a: usize, b: usize, c: usize, swaps: *usize, context: anytype) usize {
         swaps.* += 1;
         context.swap(b, a);
     }
-
-    return b;
 }
 
 fn reverseRange(a: usize, b: usize, context: anytype) void {
lib/std/sort.zig
@@ -74,29 +74,29 @@ pub fn heap(
 /// Sorts in ascending order with respect to the given `lessThan` function.
 pub fn heapContext(a: usize, b: usize, context: anytype) void {
     // build the heap in linear time.
-    var i = b / 2;
-    while (i > a) : (i -= 1) {
-        siftDown(i - 1, b, context);
+    var i = a + (b - a) / 2;
+    while (i > a) {
+        i -= 1;
+        siftDown(a, i, b, context);
     }
 
     // pop maximal elements from the heap.
     i = b;
-    while (i > a) : (i -= 1) {
-        context.swap(a, i - 1);
-        siftDown(a, i - 1, context);
+    while (i > a) {
+        i -= 1;
+        context.swap(a, i);
+        siftDown(a, a, i, context);
     }
 }
 
-fn siftDown(root: usize, n: usize, context: anytype) void {
+fn siftDown(a: usize, root: usize, n: usize, context: anytype) void {
     var node = root;
     while (true) {
-        var child = 2 * node + 1;
+        var child = a + 2 * (node - a) + 1;
         if (child >= n) break;
 
         // choose the greater child.
-        if (child + 1 < n and context.lessThan(child, child + 1)) {
-            child += 1;
-        }
+        child += @boolToInt(child + 1 < n and context.lessThan(child, child + 1));
 
         // stop if the invariant holds at `node`.
         if (!context.lessThan(node, child)) break;
@@ -138,6 +138,13 @@ const sort_funcs = &[_]fn (comptime type, anytype, anytype, comptime anytype) vo
     heap,
 };
 
+const context_sort_funcs = &[_]fn (usize, usize, anytype) void{
+    // blockContext,
+    pdqContext,
+    insertionContext,
+    heapContext,
+};
+
 const IdAndValue = struct {
     id: usize,
     value: i32,
@@ -248,11 +255,15 @@ test "sort" {
             &[_]i32{ 2, 1, 3 },
             &[_]i32{ 1, 2, 3 },
         },
+        &[_][]const i32{
+            &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 },
+            &[_]i32{ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 21, 22, 32, 39, 43, 55, 58, 59, 88 },
+        },
     };
 
     inline for (sort_funcs) |sortFn| {
         for (u8cases) |case| {
-            var buf: [8]u8 = undefined;
+            var buf: [20]u8 = undefined;
             const slice = buf[0..case[0].len];
             @memcpy(slice, case[0]);
             sortFn(u8, slice, {}, asc_u8);
@@ -260,7 +271,7 @@ test "sort" {
         }
 
         for (i32cases) |case| {
-            var buf: [8]i32 = undefined;
+            var buf: [20]i32 = undefined;
             const slice = buf[0..case[0].len];
             @memcpy(slice, case[0]);
             sortFn(i32, slice, {}, asc_i32);
@@ -308,6 +319,45 @@ test "sort descending" {
     }
 }
 
+test "sort with context in the middle of a slice" {
+    const Context = struct {
+        items: []i32,
+
+        pub fn lessThan(ctx: @This(), a: usize, b: usize) bool {
+            return ctx.items[a] < ctx.items[b];
+        }
+
+        pub fn swap(ctx: @This(), a: usize, b: usize) void {
+            return mem.swap(i32, &ctx.items[a], &ctx.items[b]);
+        }
+    };
+
+    const i32cases = [_][]const []const i32{
+        &[_][]const i32{
+            &[_]i32{ 0, 1, 8, 3, 6, 5, 4, 2, 9, 7, 10, 55, 32, 39, 58, 21, 88, 43, 22, 59 },
+            &[_]i32{ 50, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 21, 22, 32, 39, 43, 55, 58, 59, 88 },
+        },
+    };
+
+    const ranges = [_]struct { start: usize, end: usize }{
+        .{ .start = 10, .end = 20 },
+        .{ .start = 1, .end = 11 },
+        .{ .start = 3, .end = 7 },
+    };
+
+    inline for (context_sort_funcs) |sortFn| {
+        for (i32cases) |case| {
+            for (ranges) |range| {
+                var buf: [20]i32 = undefined;
+                const slice = buf[0..case[0].len];
+                @memcpy(slice, case[0]);
+                sortFn(range.start, range.end, Context{ .items = slice });
+                try testing.expectEqualSlices(i32, slice[range.start..range.end], case[1][range.start..range.end]);
+            }
+        }
+    }
+}
+
 test "sort fuzz testing" {
     var prng = std.rand.DefaultPrng.init(0x12345678);
     const random = prng.random();