master
1const std = @import("std.zig");
2const Allocator = std.mem.Allocator;
3const assert = std.debug.assert;
4const Order = std.math.Order;
5const testing = std.testing;
6const expect = testing.expect;
7const expectEqual = testing.expectEqual;
8const expectError = testing.expectError;
9
/// Priority queue for storing generic data. Initialize with `init`.
/// Provide `compareFn` that returns `Order.lt` when its second
/// argument should get popped before its third argument,
/// `Order.eq` if the arguments are of equal priority, or `Order.gt`
/// if the third argument should be popped first.
/// For example, to make `remove` return the smallest number, provide
/// `fn lessThan(context: void, a: T, b: T) Order { _ = context; return std.math.order(a, b); }`
pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareFn: fn (context: Context, a: T, b: T) Order) type {
    return struct {
        const Self = @This();

        /// Live elements laid out as an implicit binary heap: the parent of
        /// index `i` is at `(i - 1) / 2` and its children are at `2 * i + 1`
        /// and `2 * i + 2`.
        items: []T,
        /// Number of elements the current allocation can hold.
        cap: usize,
        /// Allocator used for every (re)allocation and free of the storage.
        allocator: Allocator,
        /// Value passed as the first argument of every `compareFn` call.
        context: Context,

        /// Initialize and return a priority queue.
        pub fn init(allocator: Allocator, context: Context) Self {
            return Self{
                .items = &[_]T{},
                .cap = 0,
                .allocator = allocator,
                .context = context,
            };
        }

        /// Free memory used by the queue.
        pub fn deinit(self: Self) void {
            self.allocator.free(self.allocatedSlice());
        }

        /// Insert a new element, maintaining priority.
        /// Fails only when reserving room for the element fails.
        pub fn add(self: *Self, elem: T) !void {
            try self.ensureUnusedCapacity(1);
            addUnchecked(self, elem);
        }

        /// Append `elem` and restore heap order. Room for one more element
        /// must already have been reserved by the caller.
        fn addUnchecked(self: *Self, elem: T) void {
            self.items.len += 1;
            self.items[self.items.len - 1] = elem;
            siftUp(self, self.items.len - 1);
        }

        /// Move the element at `start_index` towards the root for as long as
        /// it compares `.lt` against its parent.
        fn siftUp(self: *Self, start_index: usize) void {
            const child = self.items[start_index];
            var child_index = start_index;
            while (child_index > 0) {
                const parent_index = ((child_index - 1) >> 1);
                const parent = self.items[parent_index];
                if (compareFn(self.context, child, parent) != .lt) break;
                // Shift the parent down; `child` is written once at the end.
                self.items[child_index] = parent;
                child_index = parent_index;
            }
            self.items[child_index] = child;
        }

        /// Add each element in `items` to the queue.
        /// Capacity for the whole slice is reserved up front, so either all
        /// elements are added or none are.
        pub fn addSlice(self: *Self, items: []const T) !void {
            try self.ensureUnusedCapacity(items.len);
            for (items) |e| {
                self.addUnchecked(e);
            }
        }

        /// Look at the highest priority element in the queue. Returns
        /// `null` if empty.
        pub fn peek(self: *Self) ?T {
            return if (self.items.len > 0) self.items[0] else null;
        }

        /// Pop the highest priority element from the queue. Returns
        /// `null` if empty.
        pub fn removeOrNull(self: *Self) ?T {
            return if (self.items.len > 0) self.remove() else null;
        }

        /// Remove and return the highest priority element from the
        /// queue. Asserts that the queue is not empty.
        pub fn remove(self: *Self) T {
            return self.removeIndex(0);
        }

        /// Remove and return element at index. Indices are in the
        /// same order as iterator, which is not necessarily priority
        /// order. Asserts that `index` is in bounds.
        pub fn removeIndex(self: *Self, index: usize) T {
            assert(self.items.len > index);
            // Fill the hole with the last element, then shrink by one.
            const last = self.items[self.items.len - 1];
            const item = self.items[index];
            self.items[index] = last;
            self.items.len -= 1;

            if (index == self.items.len) {
                // Last element removed, nothing more to do.
            } else if (index == 0) {
                siftDown(self, index);
            } else {
                // The moved element may belong either below or above its new
                // position; its relation to the parent decides the direction.
                const parent_index = ((index - 1) >> 1);
                const parent = self.items[parent_index];
                if (compareFn(self.context, last, parent) == .gt) {
                    siftDown(self, index);
                } else {
                    siftUp(self, index);
                }
            }

            return item;
        }

        /// Return the number of elements remaining in the priority
        /// queue.
        pub fn count(self: Self) usize {
            return self.items.len;
        }

        /// Return the total number of elements the queue can hold
        /// before more memory is allocated.
        pub fn capacity(self: Self) usize {
            return self.cap;
        }

        /// Returns a slice of all the items plus the extra capacity, whose memory
        /// contents are `undefined`.
        fn allocatedSlice(self: Self) []T {
            // `items.len` is the length, not the capacity.
            return self.items.ptr[0..self.cap];
        }

        /// Move the element at `target_index` towards the leaves until it
        /// compares `.lt` against the child that should be popped first, or
        /// until it has no children.
        fn siftDown(self: *Self, target_index: usize) void {
            const target_element = self.items[target_index];
            var index = target_index;
            while (true) {
                // Left child is `2 * index + 1`; stop if that index overflows.
                var lesser_child_i = (std.math.mul(usize, index, 2) catch break) | 1;
                if (!(lesser_child_i < self.items.len)) break;

                const next_child_i = lesser_child_i + 1;
                if (next_child_i < self.items.len and compareFn(self.context, self.items[next_child_i], self.items[lesser_child_i]) == .lt) {
                    lesser_child_i = next_child_i;
                }

                if (compareFn(self.context, target_element, self.items[lesser_child_i]) == .lt) break;

                self.items[index] = self.items[lesser_child_i];
                index = lesser_child_i;
            }
            self.items[index] = target_element;
        }

        /// PriorityQueue takes ownership of the passed in slice. The slice must have been
        /// allocated with `allocator`.
        /// Deinitialize with `deinit`.
        pub fn fromOwnedSlice(allocator: Allocator, items: []T, context: Context) Self {
            var self = Self{
                .items = items,
                .cap = items.len,
                .allocator = allocator,
                .context = context,
            };

            // Heapify in place: sift down every index in the first half of
            // the slice, from the highest such index back to the root.
            var i = self.items.len >> 1;
            while (i > 0) {
                i -= 1;
                self.siftDown(i);
            }
            return self;
        }

        /// Ensure that the queue can fit at least `new_capacity` items.
        /// The capacity is grown in steps of `cap / 2 + 8`, so the resulting
        /// capacity may exceed `new_capacity`.
        pub fn ensureTotalCapacity(self: *Self, new_capacity: usize) !void {
            var better_capacity = self.cap;
            if (better_capacity >= new_capacity) return;
            while (true) {
                // Saturate instead of overflowing so that requests close to
                // the maximum `usize` reach the allocator and fail there.
                better_capacity +|= better_capacity / 2 + 8;
                if (better_capacity >= new_capacity) break;
            }
            try self.ensureTotalCapacityPrecise(better_capacity);
        }

        /// Ensure that the queue can fit at least `new_capacity` items,
        /// requesting exactly `new_capacity` from the allocator when the
        /// current capacity is too small.
        pub fn ensureTotalCapacityPrecise(self: *Self, new_capacity: usize) !void {
            if (self.capacity() >= new_capacity) return;

            const old_memory = self.allocatedSlice();
            const new_memory = try self.allocator.realloc(old_memory, new_capacity);
            self.items.ptr = new_memory.ptr;
            self.cap = new_memory.len;
        }

        /// Ensure that the queue can fit at least `additional_count` **more** items.
        /// Returns `error.OutOfMemory` if the required total does not fit in `usize`.
        pub fn ensureUnusedCapacity(self: *Self, additional_count: usize) !void {
            const needed = std.math.add(usize, self.items.len, additional_count) catch return error.OutOfMemory;
            return self.ensureTotalCapacity(needed);
        }

        /// Reduce allocated capacity to `new_capacity`.
        /// Asserts that `new_capacity` is neither larger than the current
        /// capacity nor smaller than the number of stored elements. If the
        /// reallocation fails the queue keeps its current allocation.
        pub fn shrinkAndFree(self: *Self, new_capacity: usize) void {
            assert(new_capacity <= self.cap);

            // Cannot shrink to smaller than the current queue size without invalidating the heap property
            assert(new_capacity >= self.items.len);

            const old_memory = self.allocatedSlice();
            const new_memory = self.allocator.realloc(old_memory, new_capacity) catch |e| switch (e) {
                error.OutOfMemory => { // no problem, capacity is still correct then.
                    return;
                },
            };

            self.items.ptr = new_memory.ptr;
            self.cap = new_memory.len;
        }

        /// Remove all elements while keeping the allocated memory.
        pub fn clearRetainingCapacity(self: *Self) void {
            self.items.len = 0;
        }

        /// Remove all elements and free the allocated memory.
        pub fn clearAndFree(self: *Self) void {
            self.allocator.free(self.allocatedSlice());
            self.items.len = 0;
            self.cap = 0;
        }

        /// Replace the first stored element (in storage order) that compares
        /// `.eq` to `elem` with `new_elem`, then restore heap order.
        /// Returns `error.ElementNotFound` if no element compares equal.
        pub fn update(self: *Self, elem: T, new_elem: T) !void {
            const update_index = blk: {
                var idx: usize = 0;
                while (idx < self.items.len) : (idx += 1) {
                    const item = self.items[idx];
                    if (compareFn(self.context, item, elem) == .eq) break :blk idx;
                }
                return error.ElementNotFound;
            };
            const old_elem: T = self.items[update_index];
            self.items[update_index] = new_elem;
            switch (compareFn(self.context, new_elem, old_elem)) {
                .lt => siftUp(self, update_index),
                .gt => siftDown(self, update_index),
                .eq => {}, // Nothing to do as the items have equal priority
            }
        }

        /// Cursor over the queue's elements in storage order.
        pub const Iterator = struct {
            queue: *PriorityQueue(T, Context, compareFn),
            /// Index of the next element to yield.
            count: usize,

            /// Return the next element, or `null` once every element has
            /// been yielded.
            pub fn next(it: *Iterator) ?T {
                if (it.count >= it.queue.items.len) return null;
                const out = it.count;
                it.count += 1;
                return it.queue.items[out];
            }

            /// Restart iteration from the first element.
            pub fn reset(it: *Iterator) void {
                it.count = 0;
            }
        };

        /// Return an iterator that walks the queue without consuming
        /// it. The iteration order may differ from the priority order.
        /// Invalidated if the heap is modified.
        pub fn iterator(self: *Self) Iterator {
            return Iterator{
                .queue = self,
                .count = 0,
            };
        }

        /// Print the queue's state with `std.debug.print`. Both the `items:`
        /// and the `array:` sections list the live elements.
        fn dump(self: *Self) void {
            const print = std.debug.print;
            print("{{ ", .{});
            print("items: ", .{});
            for (self.items) |e| {
                print("{}, ", .{e});
            }
            print("array: ", .{});
            for (self.items) |e| {
                print("{}, ", .{e});
            }
            print("len: {} ", .{self.items.len});
            print("capacity: {}", .{self.cap});
            print(" }}\n", .{});
        }
    };
}
291
/// Ascending comparison of two `u32` values; the `void` context is unused.
fn lessThan(context: void, a: u32, b: u32) Order {
    _ = context;
    if (a < b) return .lt;
    if (a > b) return .gt;
    return .eq;
}
296
/// Descending comparison of two `u32` values: the result of `lessThan`
/// with `.lt` and `.gt` swapped.
fn greaterThan(context: void, a: u32, b: u32) Order {
    return switch (lessThan(context, a, b)) {
        .lt => .gt,
        .eq => .eq,
        .gt => .lt,
    };
}
300
// Concrete queue types used by the tests below: `PQlt` orders `u32` values
// with `lessThan`, `PQgt` orders them with `greaterThan`.
const PQlt = PriorityQueue(u32, void, lessThan);
const PQgt = PriorityQueue(u32, void, greaterThan);
303
test "add and remove min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Insert in arbitrary order ...
    const inserted = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (inserted) |value| try queue.add(value);

    // ... and expect the values back smallest first.
    const expected_order = [_]u32{ 7, 12, 13, 23, 25, 54 };
    for (expected_order) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
321
test "add and remove same min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Duplicates must all be kept and come out grouped by value.
    const inserted = [_]u32{ 1, 1, 2, 2, 1, 1 };
    for (inserted) |value| try queue.add(value);

    const expected_order = [_]u32{ 1, 1, 1, 1, 2, 2 };
    for (expected_order) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
339
test "removeOrNull on empty" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Nothing was added, so there is nothing to pop.
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
346
test "edge case 3 elements" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Three values inserted in descending order.
    for ([_]u32{ 9, 3, 2 }) |value| try queue.add(value);

    for ([_]u32{ 2, 3, 9 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
358
test "peek" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    try expect(queue.peek() == null);

    for ([_]u32{ 9, 3, 2 }) |value| try queue.add(value);

    // Peeking does not consume: two peeks in a row see the same element.
    var round: usize = 0;
    while (round < 2) : (round += 1) {
        try expectEqual(@as(u32, 2), queue.peek().?);
    }
}
370
test "sift up with odd indices" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Add the values one at a time so every insertion sifts up on its own.
    const unsorted = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    var next: usize = 0;
    while (next < unsorted.len) : (next += 1) {
        try queue.add(unsorted[next]);
    }

    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    var taken: usize = 0;
    while (taken < ascending.len) : (taken += 1) {
        try expectEqual(ascending[taken], queue.remove());
    }
}
384
test "addSlice" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Hand the whole batch to the queue in a single call.
    const unsorted = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    try queue.addSlice(&unsorted);

    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    var taken: usize = 0;
    while (taken < ascending.len) : (taken += 1) {
        try expectEqual(ascending[taken], queue.remove());
    }
}
396
test "fromOwnedSlice trivial case 0" {
    // An empty owned slice yields an empty queue.
    const no_items = [0]u32{};
    const owned = try testing.allocator.dupe(u32, &no_items);
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    const stored = queue.count();
    try expectEqual(@as(usize, 0), stored);
    const popped = queue.removeOrNull();
    try expect(popped == null);
}
405
test "fromOwnedSlice trivial case 1" {
    // A single-element slice yields a queue holding exactly that element.
    const single = [1]u32{1};
    const owned = try testing.allocator.dupe(u32, &single);
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    const stored = queue.count();
    try expectEqual(@as(usize, 1), stored);
    const first = queue.remove();
    try expectEqual(single[0], first);
    try expect(queue.removeOrNull() == null);
}
416
test "fromOwnedSlice" {
    // The queue takes ownership of the duplicated slice and heapifies it.
    const unsorted = [_]u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
    const owned = try testing.allocator.dupe(u32, &unsorted);
    var queue = PQlt.fromOwnedSlice(testing.allocator, owned, {});
    defer queue.deinit();

    const ascending = [_]u32{ 1, 2, 5, 6, 7, 7, 11, 12, 13, 14, 15, 15, 16, 21, 22, 24, 24, 25 };
    var taken: usize = 0;
    while (taken < ascending.len) : (taken += 1) {
        try expectEqual(ascending[taken], queue.remove());
    }
}
428
test "add and remove max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    const inserted = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (inserted) |value| try queue.add(value);

    // With the inverted comparator the largest value is popped first.
    const expected_order = [_]u32{ 54, 25, 23, 13, 12, 7 };
    for (expected_order) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
446
test "add and remove same max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    // Duplicates must all be kept and come out grouped, largest first.
    const inserted = [_]u32{ 1, 1, 2, 2, 1, 1 };
    for (inserted) |value| try queue.add(value);

    const expected_order = [_]u32{ 2, 2, 1, 1, 1, 1 };
    for (expected_order) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
464
test "iterator" {
    var queue = PQlt.init(testing.allocator, {});
    var map = std.AutoHashMap(u32, void).init(testing.allocator);
    defer {
        queue.deinit();
        map.deinit();
    }

    // Record every inserted value in `map` as well.
    const items = [_]u32{ 54, 12, 7, 23, 25, 13 };
    for (items) |e| {
        try queue.add(e);
        try map.put(e, {});
    }

    // Every yielded value must still be pending in `map`; this rejects both
    // values that were never inserted and values yielded twice.
    var it = queue.iterator();
    while (it.next()) |e| {
        try expect(map.remove(e));
    }

    // An empty map means the iterator visited every inserted value.
    try expectEqual(@as(usize, 0), map.count());
}
486
test "remove at index" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    const items = [_]u32{ 2, 1, 8, 9, 3, 4, 5 };
    for (items) |e| {
        try queue.add(e);
    }

    // Locate the storage index of the value 2 by walking the iterator, whose
    // order matches the indices accepted by `removeIndex`.
    var it = queue.iterator();
    var idx: usize = 0;
    const two_idx = while (it.next()) |elem| {
        if (elem == 2)
            break idx;
        idx += 1;
    } else unreachable;
    const sorted_items = [_]u32{ 1, 3, 4, 5, 8, 9 };
    try expectEqual(@as(u32, 2), queue.removeIndex(two_idx));

    // The remaining values must come out in ascending order ...
    var i: usize = 0;
    while (queue.removeOrNull()) |n| : (i += 1) {
        try expect(i < sorted_items.len);
        try expectEqual(sorted_items[i], n);
    }
    // ... and all of them must have been seen.
    try expectEqual(@as(usize, sorted_items.len), i);
    try expect(queue.removeOrNull() == null);
}
512
test "iterator while empty" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    var it = queue.iterator();

    // An iterator over an empty queue is exhausted immediately.
    try expect(it.next() == null);
}
521
test "shrinkAndFree" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    // Reserve more room than the three elements added below will need.
    try queue.ensureTotalCapacity(4);
    try expect(queue.capacity() >= 4);

    for ([_]u32{ 1, 2, 3 }) |value| try queue.add(value);
    try expect(queue.capacity() >= 4);
    try expectEqual(@as(usize, 3), queue.count());

    // Shrinking to the element count leaves no spare capacity.
    queue.shrinkAndFree(3);
    try expectEqual(@as(usize, 3), queue.capacity());
    try expectEqual(@as(usize, 3), queue.count());

    // The contents survive the reallocation.
    for ([_]u32{ 1, 2, 3 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
    try expect(queue.removeOrNull() == null);
}
544
test "update min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    for ([_]u32{ 55, 44, 11 }) |value| try queue.add(value);

    // Replace every stored value with a smaller one.
    const Change = struct { from: u32, to: u32 };
    const changes = [_]Change{
        .{ .from = 55, .to = 5 },
        .{ .from = 44, .to = 4 },
        .{ .from = 11, .to = 1 },
    };
    for (changes) |change| try queue.update(change.from, change.to);

    for ([_]u32{ 1, 4, 5 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
559
test "update same min heap" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    for ([_]u32{ 1, 1, 2, 2 }) |value| try queue.add(value);

    // Each update replaces only one of the two equal elements.
    const Change = struct { from: u32, to: u32 };
    const changes = [_]Change{
        .{ .from = 1, .to = 5 },
        .{ .from = 2, .to = 4 },
    };
    for (changes) |change| try queue.update(change.from, change.to);

    for ([_]u32{ 1, 2, 4, 5 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
575
test "update max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    for ([_]u32{ 55, 44, 11 }) |value| try queue.add(value);

    // Replace every stored value with a smaller one.
    const Change = struct { from: u32, to: u32 };
    const changes = [_]Change{
        .{ .from = 55, .to = 5 },
        .{ .from = 44, .to = 1 },
        .{ .from = 11, .to = 4 },
    };
    for (changes) |change| try queue.update(change.from, change.to);

    // Largest first under the inverted comparator.
    for ([_]u32{ 5, 4, 1 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
590
test "update same max heap" {
    var queue = PQgt.init(testing.allocator, {});
    defer queue.deinit();

    for ([_]u32{ 1, 1, 2, 2 }) |value| try queue.add(value);

    // Each update replaces only one of the two equal elements.
    const Change = struct { from: u32, to: u32 };
    const changes = [_]Change{
        .{ .from = 1, .to = 5 },
        .{ .from = 2, .to = 4 },
    };
    for (changes) |change| try queue.update(change.from, change.to);

    // Largest first under the inverted comparator.
    for ([_]u32{ 5, 4, 2, 1 }) |expected| {
        try expectEqual(expected, queue.remove());
    }
}
606
test "update after remove" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    try queue.add(1);
    const popped = queue.remove();
    try expectEqual(@as(u32, 1), popped);

    // The queue is empty again, so there is nothing left to update.
    const outcome = queue.update(1, 1);
    try expectError(error.ElementNotFound, outcome);
}
615
test "siftUp in remove" {
    var queue = PQlt.init(testing.allocator, {});
    defer queue.deinit();

    try queue.addSlice(&.{ 0, 1, 100, 2, 3, 101, 102, 4, 5, 6, 7, 103, 104, 105, 106, 8 });

    // Remove 102 from wherever it is stored inside the heap.
    const live = queue.items[0..queue.count()];
    const position = std.mem.indexOfScalar(u32, live, 102).?;
    _ = queue.removeIndex(position);

    // Everything except 102 must still come out in ascending order.
    const ascending = [_]u32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 100, 101, 103, 104, 105, 106 };
    var taken: usize = 0;
    while (taken < ascending.len) : (taken += 1) {
        try expectEqual(ascending[taken], queue.remove());
    }
}
629
/// Compares `a` and `b` by the values stored at those indices in `context`,
/// ascending.
fn contextLessThan(context: []const u32, a: usize, b: usize) Order {
    const lhs = context[a];
    const rhs = context[b];
    return std.math.order(lhs, rhs);
}
633
// Queue of `usize` values compared through `contextLessThan`, with a
// `[]const u32` slice as the comparison context.
const CPQlt = PriorityQueue(usize, []const u32, contextLessThan);
635
test "add and remove min heap with context comparator" {
    const context = [_]u32{ 5, 3, 4, 2, 2, 8, 0 };

    var queue = CPQlt.init(testing.allocator, context[0..]);
    defer queue.deinit();

    // The queue stores indices into `context`; add each index in turn.
    var index: usize = 0;
    while (index < context.len) : (index += 1) {
        try queue.add(index);
    }

    // Indices come out ordered by the `context` value they refer to.
    const expected_order = [_]usize{ 6, 4, 3, 1, 2, 0, 5 };
    for (expected_order) |expected| {
        try expectEqual(expected, queue.remove());
    }
}