Commit 6e49ba77f3

Andrew Kelley <andrew@ziglang.org>
2022-03-10 06:50:27
std: add sort method to ArrayHashMap and MultiArrayList
This also adds `std.sort.sortContext` and `std.sort.insertionSortContext` which are more advanced methods that allow overriding the `swap` method. The former calls the latter for now because reworking the main sort implementation is a big task that can be done later without any changes to the API.
1 parent f736cde
lib/std/array_hash_map.zig
@@ -408,6 +408,13 @@ pub fn ArrayHashMap(
             return self.unmanaged.reIndexContext(self.allocator, self.ctx);
         }
 
+        /// Sorts the entries and then rebuilds the index.
+        /// `sort_ctx` must have this method:
+        /// `fn lessThan(ctx: @TypeOf(ctx), a_index: usize, b_index: usize) bool`
+        pub fn sort(self: *Self, sort_ctx: anytype) void {
+            return self.unmanaged.sortContext(sort_ctx, self.ctx);
+        }
+
         /// Shrinks the underlying `Entry` array to `new_len` elements and discards any associated
         /// index entries. Keeps capacity the same.
         pub fn shrinkRetainingCapacity(self: *Self, new_len: usize) void {
@@ -1169,6 +1176,22 @@ pub fn ArrayHashMapUnmanaged(
             self.index_header = new_header;
         }
 
+        /// Sorts the entries and then rebuilds the index.
+        /// `sort_ctx` must have this method:
+        /// `fn lessThan(ctx: @TypeOf(ctx), a_index: usize, b_index: usize) bool`
+        pub inline fn sort(self: *Self, sort_ctx: anytype) void {
+            if (@sizeOf(ByIndexContext) != 0)
+                @compileError("Cannot infer context " ++ @typeName(Context) ++ ", call sortContext instead.");
+            return self.sortContext(sort_ctx, undefined);
+        }
+
+        pub fn sortContext(self: *Self, sort_ctx: anytype, ctx: Context) void {
+            self.entries.sort(sort_ctx);
+            const header = self.index_header orelse return;
+            header.reset();
+            self.insertAllEntriesIntoNewHeader(if (store_hash) {} else ctx, header);
+        }
+
         /// Shrinks the underlying `Entry` array to `new_len` elements and discards any associated
         /// index entries. Keeps capacity the same.
         pub fn shrinkRetainingCapacity(self: *Self, new_len: usize) void {
@@ -1868,6 +1891,14 @@ const IndexHeader = struct {
         allocator.free(slice);
     }
 
+    /// Puts an IndexHeader into the state that it would be in after being freshly allocated.
+    fn reset(header: *IndexHeader) void {
+        const index_size = hash_map.capacityIndexSize(header.bit_index);
+        const ptr = @ptrCast([*]align(@alignOf(IndexHeader)) u8, header);
+        const nbytes = @sizeOf(IndexHeader) + header.length() * index_size;
+        @memset(ptr + @sizeOf(IndexHeader), 0xff, nbytes - @sizeOf(IndexHeader));
+    }
+
     // Verify that the header has sufficient alignment to produce aligned arrays.
     comptime {
         if (@alignOf(u32) > @alignOf(IndexHeader))
@@ -2218,6 +2249,32 @@ test "auto store_hash" {
     try testing.expect(meta.fieldInfo(HasExpensiveEqlUn.Data, .hash).field_type != void);
 }
 
+test "sort" {
+    var map = AutoArrayHashMap(i32, i32).init(std.testing.allocator);
+    defer map.deinit();
+
+    for ([_]i32{ 8, 3, 12, 10, 2, 4, 9, 5, 6, 13, 14, 15, 16, 1, 11, 17, 7 }) |x| {
+        try map.put(x, x * 3);
+    }
+
+    const C = struct {
+        keys: []i32,
+
+        pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
+            return ctx.keys[a_index] < ctx.keys[b_index];
+        }
+    };
+
+    map.sort(C{ .keys = map.keys() });
+
+    var x: i32 = 1;
+    for (map.keys()) |key, i| {
+        try testing.expect(key == x);
+        try testing.expect(map.values()[i] == x * 3);
+        x += 1;
+    }
+}
+
 pub fn getHashPtrAddrFn(comptime K: type, comptime Context: type) (fn (Context, K) u32) {
     return struct {
         fn hash(ctx: Context, key: K) u32 {
lib/std/multi_array_list.zig
@@ -392,6 +392,34 @@ pub fn MultiArrayList(comptime S: type) type {
             return result;
         }
 
+        /// `ctx` has the following method:
+        /// `fn lessThan(ctx: @TypeOf(ctx), a_index: usize, b_index: usize) bool`
+        pub fn sort(self: Self, ctx: anytype) void {
+            const SortContext = struct {
+                sub_ctx: @TypeOf(ctx),
+                slice: Slice,
+
+                pub fn swap(sc: @This(), a_index: usize, b_index: usize) void {
+                    inline for (fields) |field_info, i| {
+                        if (@sizeOf(field_info.field_type) != 0) {
+                            const field = @intToEnum(Field, i);
+                            const ptr = sc.slice.items(field);
+                            mem.swap(field_info.field_type, &ptr[a_index], &ptr[b_index]);
+                        }
+                    }
+                }
+
+                pub fn lessThan(sc: @This(), a_index: usize, b_index: usize) bool {
+                    return sc.sub_ctx.lessThan(a_index, b_index);
+                }
+            };
+
+            std.sort.sortContext(self.len, SortContext{
+                .sub_ctx = ctx,
+                .slice = self.slice(),
+            });
+        }
+
         fn capacityInBytes(capacity: usize) usize {
             const sizes_vector: std.meta.Vector(sizes.bytes.len, usize) = sizes.bytes;
             const capacity_vector = @splat(sizes.bytes.len, capacity);
lib/std/sort.zig
@@ -73,7 +73,10 @@ test "binarySearch" {
     );
 }
 
-/// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case. O(1) memory (no allocator required).
+/// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case.
+/// O(1) memory (no allocator required).
+/// This can be expressed in terms of `insertionSortContext` but the glue
+/// code is slightly longer than the direct implementation.
 pub fn insertionSort(
     comptime T: type,
     items: []T,
@@ -91,6 +94,18 @@ pub fn insertionSort(
     }
 }
 
+/// Stable in-place sort. O(n) best case, O(pow(n, 2)) worst case.
+/// O(1) memory (no allocator required).
+pub fn insertionSortContext(len: usize, context: anytype) void {
+    var i: usize = 1;
+    while (i < len) : (i += 1) {
+        var j: usize = i;
+        while (j > 0 and context.lessThan(j, j - 1)) : (j -= 1) {
+            context.swap(j, j - 1);
+        }
+    }
+}
+
 const Range = struct {
     start: usize,
     end: usize,
@@ -178,7 +193,8 @@ const Pull = struct {
     range: Range,
 };
 
-/// Stable in-place sort. O(n) best case, O(n*log(n)) worst case and average case. O(1) memory (no allocator required).
+/// Stable in-place sort. O(n) best case, O(n*log(n)) worst case and average case.
+/// O(1) memory (no allocator required).
 /// Currently implemented as block sort.
 pub fn sort(
     comptime T: type,
@@ -186,6 +202,7 @@ pub fn sort(
     context: anytype,
     comptime lessThan: fn (context: @TypeOf(context), lhs: T, rhs: T) bool,
 ) void {
+
     // Implementation ported from https://github.com/BonzaiThePenguin/WikiSort/blob/master/WikiSort.c
     var cache: [512]T = undefined;
 
@@ -291,10 +308,13 @@ pub fn sort(
 
     // then merge sort the higher levels, which can be 8-15, 16-31, 32-63, 64-127, etc.
     while (true) {
-        // if every A and B block will fit into the cache, use a special branch specifically for merging with the cache
-        // (we use < rather than <= since the block size might be one more than iterator.length())
+        // if every A and B block will fit into the cache, use a special branch
+        // specifically for merging with the cache
+        // (we use < rather than <= since the block size might be one more than
+        // iterator.length())
         if (iterator.length() < cache.len) {
-            // if four subarrays fit into the cache, it's faster to merge both pairs of subarrays into the cache,
+            // if four subarrays fit into the cache, it's faster to merge both
+            // pairs of subarrays into the cache,
             // then merge the two merged subarrays from the cache back into the original array
             if ((iterator.length() + 1) * 4 <= cache.len and iterator.length() * 4 <= items.len) {
                 iterator.begin();
@@ -767,11 +787,15 @@ pub fn sort(
                 }
             }
 
-            // when we're finished with this merge step we should have the one or two internal buffers left over, where the second buffer is all jumbled up
-            // insertion sort the second buffer, then redistribute the buffers back into the items using the opposite process used for creating the buffer
+            // when we're finished with this merge step we should have the one
+            // or two internal buffers left over, where the second buffer is all jumbled up
+            // insertion sort the second buffer, then redistribute the buffers
+            // back into the items using the opposite process used for creating the buffer
 
-            // while an unstable sort like quicksort could be applied here, in benchmarks it was consistently slightly slower than a simple insertion sort,
-            // even for tens of millions of items. this may be because insertion sort is quite fast when the data is already somewhat sorted, like it is here
+            // while an unstable sort like quicksort could be applied here, in benchmarks
+            // it was consistently slightly slower than a simple insertion sort,
+            // even for tens of millions of items. this may be because insertion
+            // sort is quite fast when the data is already somewhat sorted, like it is here
             insertionSort(T, items[buffer2.start..buffer2.end], context, lessThan);
 
             pull_index = 0;
@@ -808,6 +832,12 @@ pub fn sort(
     }
 }
 
+/// TODO currently this just calls `insertionSortContext`. The block sort implementation
+/// in this file needs to be adapted to use the sort context.
+pub fn sortContext(len: usize, context: anytype) void {
+    return insertionSortContext(len, context);
+}
+
 // merge operation without a buffer
 fn mergeInPlace(
     comptime T: type,