Commit 9f4f43cf7f

Alex Kladov <aleksey.kladov@gmail.com>
2024-05-13 16:17:11
std: align PriorityQueue and ArrayList API-wise
ArrayList uses `items` slice to store len initialized items, while PriorityQueue stores `capacity` potentially uninitialized items. This is a surprising difference in the API that leads to bugs! https://github.com/tigerbeetle/tigerbeetle/pull/1948
1 parent 8aae0d8
Changed files (1)
lib/std/priority_queue.zig
@@ -19,7 +19,7 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         const Self = @This();
 
         items: []T,
-        len: usize,
+        cap: usize,
         allocator: Allocator,
         context: Context,
 
@@ -27,7 +27,7 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         pub fn init(allocator: Allocator, context: Context) Self {
             return Self{
                 .items = &[_]T{},
-                .len = 0,
+                .cap = 0,
                 .allocator = allocator,
                 .context = context,
             };
@@ -35,7 +35,7 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
 
         /// Free memory used by the queue.
         pub fn deinit(self: Self) void {
-            self.allocator.free(self.items);
+            self.allocator.free(self.allocatedSlice());
         }
 
         /// Insert a new element, maintaining priority.
@@ -45,9 +45,9 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         }
 
         fn addUnchecked(self: *Self, elem: T) void {
-            self.items[self.len] = elem;
-            siftUp(self, self.len);
-            self.len += 1;
+            self.items.len += 1;
+            self.items[self.items.len - 1] = elem;
+            siftUp(self, self.items.len - 1);
         }
 
         fn siftUp(self: *Self, start_index: usize) void {
@@ -74,13 +74,13 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         /// Look at the highest priority element in the queue. Returns
         /// `null` if empty.
         pub fn peek(self: *Self) ?T {
-            return if (self.len > 0) self.items[0] else null;
+            return if (self.items.len > 0) self.items[0] else null;
         }
 
         /// Pop the highest priority element from the queue. Returns
         /// `null` if empty.
         pub fn removeOrNull(self: *Self) ?T {
-            return if (self.len > 0) self.remove() else null;
+            return if (self.items.len > 0) self.remove() else null;
         }
 
         /// Remove and return the highest priority element from the
@@ -93,13 +93,15 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         /// same order as iterator, which is not necessarily priority
         /// order.
         pub fn removeIndex(self: *Self, index: usize) T {
-            assert(self.len > index);
-            const last = self.items[self.len - 1];
+            assert(self.items.len > index);
+            const last = self.items[self.items.len - 1];
             const item = self.items[index];
             self.items[index] = last;
-            self.len -= 1;
+            self.items.len -= 1;
 
-            if (index == 0) {
+            if (index == self.items.len) {
+                // Last element removed, nothing more to do.
+            } else if (index == 0) {
                 siftDown(self, index);
             } else {
                 const parent_index = ((index - 1) >> 1);
@@ -117,13 +119,20 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         /// Return the number of elements remaining in the priority
         /// queue.
         pub fn count(self: Self) usize {
-            return self.len;
+            return self.items.len;
         }
 
         /// Return the number of elements that can be added to the
         /// queue before more memory is allocated.
         pub fn capacity(self: Self) usize {
-            return self.items.len;
+            return self.cap;
+        }
+
+        /// Returns a slice of all the items plus the extra capacity, whose memory
+        /// contents are `undefined`.
+        fn allocatedSlice(self: Self) []T {
+            // `items.len` is the length, not the capacity.
+            return self.items.ptr[0..self.cap];
         }
 
         fn siftDown(self: *Self, target_index: usize) void {
@@ -131,10 +140,10 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
             var index = target_index;
             while (true) {
                 var lesser_child_i = (std.math.mul(usize, index, 2) catch break) | 1;
-                if (!(lesser_child_i < self.len)) break;
+                if (!(lesser_child_i < self.items.len)) break;
 
                 const next_child_i = lesser_child_i + 1;
-                if (next_child_i < self.len and compareFn(self.context, self.items[next_child_i], self.items[lesser_child_i]) == .lt) {
+                if (next_child_i < self.items.len and compareFn(self.context, self.items[next_child_i], self.items[lesser_child_i]) == .lt) {
                     lesser_child_i = next_child_i;
                 }
 
@@ -152,12 +161,12 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
         pub fn fromOwnedSlice(allocator: Allocator, items: []T, context: Context) Self {
             var self = Self{
                 .items = items,
-                .len = items.len,
+                .cap = items.len,
                 .allocator = allocator,
                 .context = context,
             };
 
-            var i = self.len >> 1;
+            var i = self.items.len >> 1;
             while (i > 0) {
                 i -= 1;
                 self.siftDown(i);
@@ -167,39 +176,45 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
 
         /// Ensure that the queue can fit at least `new_capacity` items.
         pub fn ensureTotalCapacity(self: *Self, new_capacity: usize) !void {
-            var better_capacity = self.capacity();
+            var better_capacity = self.cap;
             if (better_capacity >= new_capacity) return;
             while (true) {
                 better_capacity += better_capacity / 2 + 8;
                 if (better_capacity >= new_capacity) break;
             }
-            self.items = try self.allocator.realloc(self.items, better_capacity);
+            const old_memory = self.allocatedSlice();
+            const new_memory = try self.allocator.realloc(old_memory, better_capacity);
+            self.items.ptr = new_memory.ptr;
+            self.cap = new_memory.len;
         }
 
         /// Ensure that the queue can fit at least `additional_count` **more** item.
         pub fn ensureUnusedCapacity(self: *Self, additional_count: usize) !void {
-            return self.ensureTotalCapacity(self.len + additional_count);
+            return self.ensureTotalCapacity(self.items.len + additional_count);
         }
 
-        /// Reduce allocated capacity to `new_len`.
-        pub fn shrinkAndFree(self: *Self, new_len: usize) void {
-            assert(new_len <= self.items.len);
+        /// Reduce allocated capacity to `new_capacity`.
+        pub fn shrinkAndFree(self: *Self, new_capacity: usize) void {
+            assert(new_capacity <= self.cap);
 
             // Cannot shrink to smaller than the current queue size without invalidating the heap property
-            assert(new_len >= self.len);
+            assert(new_capacity >= self.items.len);
 
-            self.items = self.allocator.realloc(self.items[0..], new_len) catch |e| switch (e) {
+            const old_memory = self.allocatedSlice();
+            const new_memory = self.allocator.realloc(old_memory, new_capacity) catch |e| switch (e) {
                 error.OutOfMemory => { // no problem, capacity is still correct then.
-                    self.items.len = new_len;
                     return;
                 },
             };
+
+            self.items.ptr = new_memory.ptr;
+            self.cap = new_memory.len;
         }
 
         pub fn update(self: *Self, elem: T, new_elem: T) !void {
             const update_index = blk: {
                 var idx: usize = 0;
-                while (idx < self.len) : (idx += 1) {
+                while (idx < self.items.len) : (idx += 1) {
                     const item = self.items[idx];
                     if (compareFn(self.context, item, elem) == .eq) break :blk idx;
                 }
@@ -219,7 +234,7 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
             count: usize,
 
             pub fn next(it: *Iterator) ?T {
-                if (it.count >= it.queue.len) return null;
+                if (it.count >= it.queue.items.len) return null;
                 const out = it.count;
                 it.count += 1;
                 return it.queue.items[out];
@@ -244,16 +259,15 @@ pub fn PriorityQueue(comptime T: type, comptime Context: type, comptime compareF
             const print = std.debug.print;
             print("{{ ", .{});
             print("items: ", .{});
-            for (self.items, 0..) |e, i| {
-                if (i >= self.len) break;
+            for (self.items) |e| {
                 print("{}, ", .{e});
             }
             print("array: ", .{});
             for (self.items) |e| {
                 print("{}, ", .{e});
             }
-            print("len: {} ", .{self.len});
-            print("capacity: {}", .{self.capacity()});
+            print("len: {} ", .{self.items.len});
+            print("capacity: {}", .{self.cap});
             print(" }}\n", .{});
         }
     };
@@ -369,7 +383,7 @@ test "fromOwnedSlice trivial case 0" {
     const queue_items = try testing.allocator.dupe(u32, &items);
     var queue = PQlt.fromOwnedSlice(testing.allocator, queue_items[0..], {});
     defer queue.deinit();
-    try expectEqual(@as(usize, 0), queue.len);
+    try expectEqual(@as(usize, 0), queue.count());
     try expect(queue.removeOrNull() == null);
 }
 
@@ -379,7 +393,7 @@ test "fromOwnedSlice trivial case 1" {
     var queue = PQlt.fromOwnedSlice(testing.allocator, queue_items[0..], {});
     defer queue.deinit();
 
-    try expectEqual(@as(usize, 1), queue.len);
+    try expectEqual(@as(usize, 1), queue.count());
     try expectEqual(items[0], queue.remove());
     try expect(queue.removeOrNull() == null);
 }
@@ -500,11 +514,11 @@ test "shrinkAndFree" {
     try queue.add(2);
     try queue.add(3);
     try expect(queue.capacity() >= 4);
-    try expectEqual(@as(usize, 3), queue.len);
+    try expectEqual(@as(usize, 3), queue.count());
 
     queue.shrinkAndFree(3);
     try expectEqual(@as(usize, 3), queue.capacity());
-    try expectEqual(@as(usize, 3), queue.len);
+    try expectEqual(@as(usize, 3), queue.count());
 
     try expectEqual(@as(u32, 1), queue.remove());
     try expectEqual(@as(u32, 2), queue.remove());
@@ -589,7 +603,7 @@ test "siftUp in remove" {
 
     try queue.addSlice(&.{ 0, 1, 100, 2, 3, 101, 102, 4, 5, 6, 7, 103, 104, 105, 106, 8 });
 
-    _ = queue.removeIndex(std.mem.indexOfScalar(u32, queue.items[0..queue.len], 102).?);
+    _ = queue.removeIndex(std.mem.indexOfScalar(u32, queue.items[0..queue.count()], 102).?);
 
     const sorted_items = [_]u32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 100, 101, 103, 104, 105, 106 };
     for (sorted_items) |e| {