master
  1const std = @import("std.zig");
  2const Allocator = std.mem.Allocator;
  3const assert = std.debug.assert;
  4const Order = std.math.Order;
  5const testing = std.testing;
  6const expect = testing.expect;
  7const expectEqual = testing.expectEqual;
  8const expectError = testing.expectError;
  9
 10/// Priority queue for storing generic data. Initialize with `init`.
 11/// Provide `compareFn` that returns `Order.lt` when its second
 12/// argument should get popped before its third argument,
 13/// `Order.eq` if the arguments are of equal priority, or `Order.gt`
 14/// if the third argument should be popped first.
 15/// For example, to make `pop` return the smallest number, provide
 16/// `fn lessThan(context: void, a: T, b: T) Order { _ = context; return std.math.order(a, b); }`
 17pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareFn: fn (context: Context, a: T, b: T) Order) type {
 18    return struct {
 19        const Self = @This();
 20
 21        items: []T,
 22        cap: usize,
 23        allocator: Allocator,
 24        context: Context,
 25
 26        /// Initialize and return a priority queue.
 27        pub fn init(allocator: Allocator, context: Context) Self {
 28            return Self{
 29                .items = &[_]T{},
 30                .cap = 0,
 31                .allocator = allocator,
 32                .context = context,
 33            };
 34        }
 35
 36        /// Free memory used by the queue.
 37        pub fn deinit(self: Self) void {
 38            self.allocator.free(self.allocatedSlice());
 39        }
 40
 41        /// Insert a new element, maintaining priority.
 42        pub fn add(self: *Self, elem: T) !void {
 43            try self.ensureUnusedCapacity(1);
 44            addUnchecked(self, elem);
 45        }
 46
 47        fn addUnchecked(self: *Self, elem: T) void {
 48            self.items.len += 1;
 49            self.items[self.items.len - 1] = elem;
 50            siftUp(self, self.items.len - 1);
 51        }
 52
 53        fn siftUp(self: *Self, start_index: usize) void {
 54            const child = self.items[start_index];
 55            var child_index = start_index;
 56            while (child_index > 0) {
 57                const parent_index = ((child_index - 1) >> 1);
 58                const parent = self.items[parent_index];
 59                if (compareFn(self.context, child, parent) != .lt) break;
 60                self.items[child_index] = parent;
 61                child_index = parent_index;
 62            }
 63            self.items[child_index] = child;
 64        }
 65
 66        /// Add each element in `items` to the queue.
 67        pub fn addSlice(self: *Self, items: []const T) !void {
 68            try self.ensureUnusedCapacity(items.len);
 69            for (items) |e| {
 70                self.addUnchecked(e);
 71            }
 72        }
 73
 74        /// Look at the highest priority element in the queue. Returns
 75        /// `null` if empty.
 76        pub fn peek(self: *Self) ?T {
 77            return if (self.items.len > 0) self.items[0] else null;
 78        }
 79
 80        /// Pop the highest priority element from the queue. Returns
 81        /// `null` if empty.
 82        pub fn removeOrNull(self: *Self) ?T {
 83            return if (self.items.len > 0) self.remove() else null;
 84        }
 85
 86        /// Remove and return the highest priority element from the
 87        /// queue.
 88        pub fn remove(self: *Self) T {
 89            return self.removeIndex(0);
 90        }
 91
 92        /// Remove and return element at index. Indices are in the
 93        /// same order as iterator, which is not necessarily priority
 94        /// order.
 95        pub fn removeIndex(self: *Self, index: usize) T {
 96            assert(self.items.len > index);
 97            const last = self.items[self.items.len - 1];
 98            const item = self.items[index];
 99            self.items[index] = last;
100            self.items.len -= 1;
101
102            if (index == self.items.len) {
103                // Last element removed, nothing more to do.
104            } else if (index == 0) {
105                siftDown(self, index);
106            } else {
107                const parent_index = ((index - 1) >> 1);
108                const parent = self.items[parent_index];
109                if (compareFn(self.context, last, parent) == .gt) {
110                    siftDown(self, index);
111                } else {
112                    siftUp(self, index);
113                }
114            }
115
116            return item;
117        }
118
119        /// Return the number of elements remaining in the priority
120        /// queue.
121        pub fn count(self: Self) usize {
122            return self.items.len;
123        }
124
125        /// Return the number of elements that can be added to the
126        /// queue before more memory is allocated.
127        pub fn capacity(self: Self) usize {
128            return self.cap;
129        }
130
131        /// Returns a slice of all the items plus the extra capacity, whose memory
132        /// contents are `undefined`.
133        fn allocatedSlice(self: Self) []T {
134            // `items.len` is the length, not the capacity.
135            return self.items.ptr[0..self.cap];
136        }
137
138        fn siftDown(self: *Self, target_index: usize) void {
139            const target_element = self.items[target_index];
140            var index = target_index;
141            while (true) {
142                var lesser_child_i = (std.math.mul(usize, index, 2) catch break) | 1;
143                if (!(lesser_child_i < self.items.len)) break;
144
145                const next_child_i = lesser_child_i + 1;
146                if (next_child_i < self.items.len and compareFn(self.context, self.items[next_child_i], self.items[lesser_child_i]) == .lt) {
147                    lesser_child_i = next_child_i;
148                }
149
150                if (compareFn(self.context, target_element, self.items[lesser_child_i]) == .lt) break;
151
152                self.items[index] = self.items[lesser_child_i];
153                index = lesser_child_i;
154            }
155            self.items[index] = target_element;
156        }
157
158        /// PriorityQueue takes ownership of the passed in slice. The slice must have been
159        /// allocated with `allocator`.
160        /// Deinitialize with `deinit`.
161        pub fn fromOwnedSlice(allocator: Allocator, items: []T, context: Context) Self {
162            var self = Self{
163                .items = items,
164                .cap = items.len,
165                .allocator = allocator,
166                .context = context,
167            };
168
169            var i = self.items.len >> 1;
170            while (i > 0) {
171                i -= 1;
172                self.siftDown(i);
173            }
174            return self;
175        }
176
177        /// Ensure that the queue can fit at least `new_capacity` items.
178        pub fn ensureTotalCapacity(self: *Self, new_capacity: usize) !void {
179            var better_capacity = self.cap;
180            if (better_capacity >= new_capacity) return;
181            while (true) {
182                better_capacity += better_capacity / 2 + 8;
183                if (better_capacity >= new_capacity) break;
184            }
185            try self.ensureTotalCapacityPrecise(better_capacity);
186        }
187
188        pub fn ensureTotalCapacityPrecise(self: *Self, new_capacity: usize) !void {
189            if (self.capacity() >= new_capacity) return;
190
191            const old_memory = self.allocatedSlice();
192            const new_memory = try self.allocator.realloc(old_memory, new_capacity);
193            self.items.ptr = new_memory.ptr;
194            self.cap = new_memory.len;
195        }
196
197        /// Ensure that the queue can fit at least `additional_count` **more** item.
198        pub fn ensureUnusedCapacity(self: *Self, additional_count: usize) !void {
199            return self.ensureTotalCapacity(self.items.len + additional_count);
200        }
201
202        /// Reduce allocated capacity to `new_capacity`.
203        pub fn shrinkAndFree(self: *Self, new_capacity: usize) void {
204            assert(new_capacity <= self.cap);
205
206            // Cannot shrink to smaller than the current queue size without invalidating the heap property
207            assert(new_capacity >= self.items.len);
208
209            const old_memory = self.allocatedSlice();
210            const new_memory = self.allocator.realloc(old_memory, new_capacity) catch |e| switch (e) {
211                error.OutOfMemory => { // no problem, capacity is still correct then.
212                    return;
213                },
214            };
215
216            self.items.ptr = new_memory.ptr;
217            self.cap = new_memory.len;
218        }
219
220        pub fn clearRetainingCapacity(self: *Self) void {
221            self.items.len = 0;
222        }
223
224        pub fn clearAndFree(self: *Self) void {
225            self.allocator.free(self.allocatedSlice());
226            self.items.len = 0;
227            self.cap = 0;
228        }
229
230        pub fn update(self: *Self, elem: T, new_elem: T) !void {
231            const update_index = blk: {
232                var idx: usize = 0;
233                while (idx < self.items.len) : (idx += 1) {
234                    const item = self.items[idx];
235                    if (compareFn(self.context, item, elem) == .eq) break :blk idx;
236                }
237                return error.ElementNotFound;
238            };
239            const old_elem: T = self.items[update_index];
240            self.items[update_index] = new_elem;
241            switch (compareFn(self.context, new_elem, old_elem)) {
242                .lt => siftUp(self, update_index),
243                .gt => siftDown(self, update_index),
244                .eq => {}, // Nothing to do as the items have equal priority
245            }
246        }
247
248        pub const Iterator = struct {
249            queue: *PriorityQueue(T, Context, compareFn),
250            count: usize,
251
252            pub fn next(it: *Iterator) ?T {
253                if (it.count >= it.queue.items.len) return null;
254                const out = it.count;
255                it.count += 1;
256                return it.queue.items[out];
257            }
258
259            pub fn reset(it: *Iterator) void {
260                it.count = 0;
261            }
262        };
263
264        /// Return an iterator that walks the queue without consuming
265        /// it. The iteration order may differ from the priority order.
266        /// Invalidated if the heap is modified.
267        pub fn iterator(self: *Self) Iterator {
268            return Iterator{
269                .queue = self,
270                .count = 0,
271            };
272        }
273
274        fn dump(self: *Self) void {
275            const print = std.debug.print;
276            print("{{ ", .{});
277            print("items: ", .{});
278            for (self.items) |e| {
279                print("{}, ", .{e});
280            }
281            print("array: ", .{});
282            for (self.items) |e| {
283                print("{}, ", .{e});
284            }
285            print("len: {} ", .{self.items.len});
286            print("capacity: {}", .{self.cap});
287            print(" }}\n", .{});
288        }
289    };
290}
291
/// Ascending comparator for `u32`: reports how `a` orders relative to `b`.
/// The `void` context carries no data and is ignored.
fn lessThan(context: void, a: u32, b: u32) Order {
    _ = context;
    if (a < b) return Order.lt;
    if (a > b) return Order.gt;
    return Order.eq;
}
296
/// Descending comparator for `u32`: the exact opposite of `lessThan`.
fn greaterThan(context: void, a: u32, b: u32) Order {
    return switch (lessThan(context, a, b)) {
        .lt => Order.gt,
        .eq => Order.eq,
        .gt => Order.lt,
    };
}
300
// `u32` queue ordered by `lessThan`.
const PQlt = PriorityQueue(u32, void, lessThan);
// `u32` queue ordered by `greaterThan`.
const PQgt = PriorityQueue(u32, void, greaterThan);
303
test "add and remove min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Insert one element at a time, in no particular order.
    const inserted = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // The values must come back out in ascending order.
    const expected = [_]u32{ 7, 12, 13, 23, 25, 54 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
321
test "add and remove same min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Duplicate priorities must all be kept and returned.
    const inserted = [_]u32{ 1, 1, 2, 2, 1, 1 };
    for (inserted) |value| {
        try queue.add(value);
    }

    const expected = [_]u32{ 1, 1, 1, 1, 2, 2 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
339
test "removeOrNull on empty" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // A freshly initialized queue has nothing to hand out.
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
346
test "edge case 3 elements" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Strictly descending insertion order.
    const inserted = [_]u32{ 9, 3, 2 };
    for (inserted) |value| {
        try queue.add(value);
    }

    const expected = [_]u32{ 2, 3, 9 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
358
test "peek" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Nothing to look at while the queue is empty.
    try expect(queue.peek() == null);

    const inserted = [_]u32{ 9, 3, 2 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // Peeking twice yields the same element: it must not be consumed.
    var round: usize = 0;
    while (round < 2) : (round += 1) {
        try expectEqual(@as(u32, 2), queue.peek().?);
    }
}
370
test "sift up with odd indices" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const input = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    var next: usize = 0;
    while (next < input.len) : (next += 1) {
        try queue.add(input[next]);
    }

    // Removal order must be the input sorted ascending.
    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    for (ascending) |want| {
        try expectEqual(want, queue.remove());
    }
}
384
test "addSlice" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Insert the whole batch with a single call.
    const input = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    try queue.addSlice(&input);

    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    for (ascending) |want| {
        try expectEqual(want, queue.remove());
    }
}
396
test "fromOwnedSlice trivial case 0" {
    // Hand over an empty, allocator-owned slice.
    const owned = try testing.allocator.dupe(u32, &[0]u32{});
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    try expectEqual(@as(usize, 0), queue.count());
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
405
test "fromOwnedSlice trivial case 1" {
    // Hand over an allocator-owned slice holding a single element.
    const source = [1]u32{1};
    const owned = try testing.allocator.dupe(u32, &source);
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    try expectEqual(@as(usize, 1), queue.count());
    try expectEqual(source[0], queue.remove());

    // The only element is gone now.
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
416
test "fromOwnedSlice" {
    // Build the queue directly from an unordered, allocator-owned slice.
    const input = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    const owned = try testing.allocator.dupe(u32, &input);
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    for (ascending) |want| {
        try expectEqual(want, queue.remove());
    }
}
428
test "add and remove max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // With the inverted comparator the largest value comes out first.
    const expected = [_]u32{ 54, 25, 23, 13, 12, 7 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
446
test "add and remove same max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    // Duplicate priorities must all be kept and returned.
    const inserted = [_]u32{ 1, 1, 2, 2, 1, 1 };
    for (inserted) |value| {
        try queue.add(value);
    }

    const expected = [_]u32{ 2, 2, 1, 1, 1, 1 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
464
test "iterator" {
    var queue = PQlt.init(testing.allocator, {});
    var seen = std.AutoHashMap(u32, void).init(testing.allocator);
    // Deferred calls run in reverse order: the queue is released first, then the map.
    defer seen.deinit();
    defer queue.deinit();

    // Record every inserted value in the map as well.
    const input = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (input) |value| {
        try queue.add(value);
        try seen.put(value, {});
    }

    // Strike off each value the iterator yields.
    var it = queue.iterator();
    while (it.next()) |value| {
        _ = seen.remove(value);
    }

    // Every inserted value must have been visited.
    try expectEqual(@as(usize, 0), seen.count());
}
486
test "remove at index" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const items = [_]u32{ 2, 1, 8, 9, 3, 4, 5 };
    for (items) |e| {
        try queue.add(e);
    }

    // Find the storage index currently holding the value 2.
    var it = queue.iterator();
    var idx: usize = 0;
    const two_idx = while (it.next()) |elem| {
        if (elem == 2)
            break idx;
        idx += 1;
    } else unreachable;

    // `expectEqual` takes the expected value first, so that failure
    // messages label the two values correctly.
    try expectEqual(@as(u32, 2), queue.removeIndex(two_idx));

    // Everything except the removed 2 must come out in ascending order.
    const sorted_items = [_]u32{ 1, 3, 4, 5, 8, 9 };
    var i: usize = 0;
    while (queue.removeOrNull()) |n| : (i += 1) {
        try expectEqual(sorted_items[i], n);
    }
    // The drain loop must have produced every expected value.
    try expectEqual(@as(usize, sorted_items.len), i);
    try expectEqual(@as(?u32, null), queue.removeOrNull());
}
512
test "iterator while empty" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    var it = queue.iterator();

    // An iterator over an empty queue is exhausted immediately.
    // Expected value goes first, matching the other tests in this file.
    try expectEqual(@as(?u32, null), it.next());
}
521
test "shrinkAndFree" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Reserve more room than the test is going to use.
    try queue.ensureTotalCapacity(4);
    try expect(queue.capacity() >= 4);

    // Already ascending, so the same array doubles as the expected removal order.
    const values = [_]u32{ 1, 2, 3 };
    for (values) |value| {
        try queue.add(value);
    }
    try expect(queue.capacity() >= 4);
    try expectEqual(@as(usize, 3), queue.count());

    // Shrink the allocation down to exactly the element count.
    queue.shrinkAndFree(3);
    try expectEqual(@as(usize, 3), queue.capacity());
    try expectEqual(@as(usize, 3), queue.count());

    // The contents must have survived the reallocation.
    for (values) |want| {
        try expectEqual(want, queue.remove());
    }
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
544
test "update min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 55, 44, 11 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // Each pair is { current value, replacement }.
    const changes = [_][2]u32{
        .{ 55, 5 },
        .{ 44, 4 },
        .{ 11, 1 },
    };
    for (changes) |change| {
        try queue.update(change[0], change[1]);
    }

    const expected = [_]u32{ 1, 4, 5 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
559
test "update same min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 1, 1, 2, 2 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // Only one of each duplicated value gets replaced.
    // Each pair is { current value, replacement }.
    const changes = [_][2]u32{
        .{ 1, 5 },
        .{ 2, 4 },
    };
    for (changes) |change| {
        try queue.update(change[0], change[1]);
    }

    const expected = [_]u32{ 1, 2, 4, 5 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
575
test "update max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 55, 44, 11 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // Each pair is { current value, replacement }.
    const changes = [_][2]u32{
        .{ 55, 5 },
        .{ 44, 1 },
        .{ 11, 4 },
    };
    for (changes) |change| {
        try queue.update(change[0], change[1]);
    }

    const expected = [_]u32{ 5, 4, 1 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
590
test "update same max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 1, 1, 2, 2 };
    for (inserted) |value| {
        try queue.add(value);
    }

    // Only one of each duplicated value gets replaced.
    // Each pair is { current value, replacement }.
    const changes = [_][2]u32{
        .{ 1, 5 },
        .{ 2, 4 },
    };
    for (changes) |change| {
        try queue.update(change[0], change[1]);
    }

    const expected = [_]u32{ 5, 4, 2, 1 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
606
test "update after remove" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Insert a single value and take it straight back out.
    try queue.add(1);
    const popped = queue.remove();
    try expectEqual(@as(u32, 1), popped);

    // The removed value must no longer be found by `update`.
    const outcome = queue.update(1, 1);
    try expectError(error.ElementNotFound, outcome);
}
615
test "siftUp in remove" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const input = [_]u32{ 0, 1, 100, 2, 3, 101, 102, 4, 5, 6, 7, 103, 104, 105, 106, 8 };
    try queue.addSlice(&input);

    // Remove 102 from wherever it currently sits in the backing storage.
    const victim_index = std.mem.indexOfScalar(u32, queue.items, 102).?;
    _ = queue.removeIndex(victim_index);

    // Everything else must still come out in ascending order.
    const ascending = [_]u32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 100, 101, 103, 104, 105, 106 };
    for (ascending) |want| {
        try expectEqual(want, queue.remove());
    }
}
629
/// Comparator whose elements are indices into `context`: orders `a` and
/// `b` by the values stored at those positions.
fn contextLessThan(context: []const u32, a: usize, b: usize) Order {
    const lhs = context[a];
    const rhs = context[b];
    return std.math.order(lhs, rhs);
}
633
// Queue of `usize` indices ordered by `contextLessThan`, with a `[]const u32` context.
const CPQlt = PriorityQueue(usize, []const u32, contextLessThan);
635
test "add and remove min heap with context comparator" {
    // The queue stores indices; their priorities are looked up in this table.
    const priorities = [_]u32{ 5, 3, 4, 2, 2, 8, 0 };

    var queue = CPQlt.init(testing.allocator, &priorities);
    defer queue.deinit();

    // Insert every index of the table, in ascending index order.
    var index: usize = 0;
    while (index < priorities.len) : (index += 1) {
        try queue.add(index);
    }

    // Indices come back ordered by the priority they refer to.
    const expected = [_]usize{ 6, 4, 3, 1, 2, 0, 5 };
    for (expected) |want| {
        try expectEqual(want, queue.remove());
    }
}
656}