Commit 242ab81112

Andrew Kelley <andrew@ziglang.org>
2024-03-16 23:45:10
std: introduce pointer stability locks to hash maps (#17719)
This adds std.debug.SafetyLock and uses it in std.HashMapUnmanaged by adding lockPointers() and unlockPointers(). This provides a way to detect when an illegal modification has happened and panic rather than invoke undefined behavior.
1 parent 1b8d1b1
lib/std/array_hash_map.zig
@@ -137,6 +137,23 @@ pub fn ArrayHashMap(
             self.* = undefined;
         }
 
+        /// Puts the hash map into a state where any method call that would
+        /// cause an existing key or value pointer to become invalidated will
+        /// instead trigger an assertion.
+        ///
+        /// An additional call to `lockPointers` in such state also triggers an
+        /// assertion.
+        ///
+        /// `unlockPointers` returns the hash map to the previous state.
+        pub fn lockPointers(self: *Self) void {
+            self.unmanaged.lockPointers();
+        }
+
+        /// Undoes a call to `lockPointers`.
+        pub fn unlockPointers(self: *Self) void {
+            self.unmanaged.unlockPointers();
+        }
+
         /// Clears the map but retains the backing allocation for future use.
         pub fn clearRetainingCapacity(self: *Self) void {
             return self.unmanaged.clearRetainingCapacity();
@@ -403,6 +420,7 @@ pub fn ArrayHashMap(
         /// Set the map to an empty state, making deinitialization a no-op, and
         /// returning a copy of the original.
         pub fn move(self: *Self) Self {
+            self.pointer_stability.assertUnlocked();
             const result = self.*;
             self.unmanaged = .{};
             return result;
@@ -495,6 +513,9 @@ pub fn ArrayHashMapUnmanaged(
         /// by how many total indexes there are.
         index_header: ?*IndexHeader = null,
 
+        /// Used to detect memory safety violations.
+        pointer_stability: std.debug.SafetyLock = .{},
+
         comptime {
             std.hash_map.verifyContext(Context, K, K, u32, true);
         }
@@ -589,6 +610,7 @@ pub fn ArrayHashMapUnmanaged(
         /// Note that this does not free keys or values.  You must take care of that
         /// before calling this function, if it is needed.
         pub fn deinit(self: *Self, allocator: Allocator) void {
+            self.pointer_stability.assertUnlocked();
             self.entries.deinit(allocator);
             if (self.index_header) |header| {
                 header.free(allocator);
@@ -596,8 +618,28 @@ pub fn ArrayHashMapUnmanaged(
             self.* = undefined;
         }
 
+        /// Puts the hash map into a state where any method call that would
+        /// cause an existing key or value pointer to become invalidated will
+        /// instead trigger an assertion.
+        ///
+        /// An additional call to `lockPointers` in such state also triggers an
+        /// assertion.
+        ///
+        /// `unlockPointers` returns the hash map to the previous state.
+        pub fn lockPointers(self: *Self) void {
+            self.pointer_stability.lock();
+        }
+
+        /// Undoes a call to `lockPointers`.
+        pub fn unlockPointers(self: *Self) void {
+            self.pointer_stability.unlock();
+        }
+
         /// Clears the map but retains the backing allocation for future use.
         pub fn clearRetainingCapacity(self: *Self) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             self.entries.len = 0;
             if (self.index_header) |header| {
                 switch (header.capacityIndexType()) {
@@ -610,6 +652,9 @@ pub fn ArrayHashMapUnmanaged(
 
         /// Clears the map and releases the backing allocation
         pub fn clearAndFree(self: *Self, allocator: Allocator) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             self.entries.shrinkAndFree(allocator, 0);
             if (self.index_header) |header| {
                 header.free(allocator);
@@ -795,6 +840,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.ensureTotalCapacityContext(allocator, new_capacity, undefined);
         }
         pub fn ensureTotalCapacityContext(self: *Self, allocator: Allocator, new_capacity: usize, ctx: Context) !void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             if (new_capacity <= linear_scan_max) {
                 try self.entries.ensureTotalCapacity(allocator, new_capacity);
                 return;
@@ -1079,6 +1127,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.fetchSwapRemoveContextAdapted(key, ctx, undefined);
         }
         pub fn fetchSwapRemoveContextAdapted(self: *Self, key: anytype, key_ctx: anytype, ctx: Context) ?KV {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             return self.fetchRemoveByKey(key, key_ctx, if (store_hash) {} else ctx, .swap);
         }
 
@@ -1100,6 +1151,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.fetchOrderedRemoveContextAdapted(key, ctx, undefined);
         }
         pub fn fetchOrderedRemoveContextAdapted(self: *Self, key: anytype, key_ctx: anytype, ctx: Context) ?KV {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             return self.fetchRemoveByKey(key, key_ctx, if (store_hash) {} else ctx, .ordered);
         }
 
@@ -1121,6 +1175,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.swapRemoveContextAdapted(key, ctx, undefined);
         }
         pub fn swapRemoveContextAdapted(self: *Self, key: anytype, key_ctx: anytype, ctx: Context) bool {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             return self.removeByKey(key, key_ctx, if (store_hash) {} else ctx, .swap);
         }
 
@@ -1142,6 +1199,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.orderedRemoveContextAdapted(key, ctx, undefined);
         }
         pub fn orderedRemoveContextAdapted(self: *Self, key: anytype, key_ctx: anytype, ctx: Context) bool {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             return self.removeByKey(key, key_ctx, if (store_hash) {} else ctx, .ordered);
         }
 
@@ -1154,6 +1214,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.swapRemoveAtContext(index, undefined);
         }
         pub fn swapRemoveAtContext(self: *Self, index: usize, ctx: Context) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             self.removeByIndex(index, if (store_hash) {} else ctx, .swap);
         }
 
@@ -1167,6 +1230,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.orderedRemoveAtContext(index, undefined);
         }
         pub fn orderedRemoveAtContext(self: *Self, index: usize, ctx: Context) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             self.removeByIndex(index, if (store_hash) {} else ctx, .ordered);
         }
 
@@ -1196,6 +1262,7 @@ pub fn ArrayHashMapUnmanaged(
         /// Set the map to an empty state, making deinitialization a no-op, and
         /// returning a copy of the original.
         pub fn move(self: *Self) Self {
+            self.pointer_stability.assertUnlocked();
             const result = self.*;
             self.* = .{};
             return result;
@@ -1271,6 +1338,9 @@ pub fn ArrayHashMapUnmanaged(
             sort_ctx: anytype,
             ctx: Context,
         ) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             switch (mode) {
                 .stable => self.entries.sort(sort_ctx),
                 .unstable => self.entries.sortUnstable(sort_ctx),
@@ -1288,6 +1358,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.shrinkRetainingCapacityContext(new_len, undefined);
         }
         pub fn shrinkRetainingCapacityContext(self: *Self, new_len: usize, ctx: Context) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             // Remove index entries from the new length onwards.
             // Explicitly choose to ONLY remove index entries and not the underlying array list
             // entries as we're going to remove them in the subsequent shrink call.
@@ -1307,6 +1380,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.shrinkAndFreeContext(allocator, new_len, undefined);
         }
         pub fn shrinkAndFreeContext(self: *Self, allocator: Allocator, new_len: usize, ctx: Context) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             // Remove index entries from the new length onwards.
             // Explicitly choose to ONLY remove index entries and not the underlying array list
             // entries as we're going to remove them in the subsequent shrink call.
@@ -1325,6 +1401,9 @@ pub fn ArrayHashMapUnmanaged(
             return self.popContext(undefined);
         }
         pub fn popContext(self: *Self, ctx: Context) KV {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
+
             const item = self.entries.get(self.entries.len - 1);
             if (self.index_header) |header|
                 self.removeFromIndexByIndex(self.entries.len - 1, if (store_hash) {} else ctx, header);
@@ -1346,9 +1425,13 @@ pub fn ArrayHashMapUnmanaged(
             return if (self.entries.len == 0) null else self.popContext(ctx);
         }
 
-        // ------------------ No pub fns below this point ------------------
-
-        fn fetchRemoveByKey(self: *Self, key: anytype, key_ctx: anytype, ctx: ByIndexContext, comptime removal_type: RemovalType) ?KV {
+        fn fetchRemoveByKey(
+            self: *Self,
+            key: anytype,
+            key_ctx: anytype,
+            ctx: ByIndexContext,
+            comptime removal_type: RemovalType,
+        ) ?KV {
             const header = self.index_header orelse {
                 // Linear scan.
                 const key_hash = if (store_hash) key_ctx.hash(key) else {};
@@ -1377,7 +1460,15 @@ pub fn ArrayHashMapUnmanaged(
                 .u32 => self.fetchRemoveByKeyGeneric(key, key_ctx, ctx, header, u32, removal_type),
             };
         }
-        fn fetchRemoveByKeyGeneric(self: *Self, key: anytype, key_ctx: anytype, ctx: ByIndexContext, header: *IndexHeader, comptime I: type, comptime removal_type: RemovalType) ?KV {
+        fn fetchRemoveByKeyGeneric(
+            self: *Self,
+            key: anytype,
+            key_ctx: anytype,
+            ctx: ByIndexContext,
+            header: *IndexHeader,
+            comptime I: type,
+            comptime removal_type: RemovalType,
+        ) ?KV {
             const indexes = header.indexes(I);
             const entry_index = self.removeFromIndexByKey(key, key_ctx, header, I, indexes) orelse return null;
             const slice = self.entries.slice();
@@ -1389,7 +1480,13 @@ pub fn ArrayHashMapUnmanaged(
             return removed_entry;
         }
 
-        fn removeByKey(self: *Self, key: anytype, key_ctx: anytype, ctx: ByIndexContext, comptime removal_type: RemovalType) bool {
+        fn removeByKey(
+            self: *Self,
+            key: anytype,
+            key_ctx: anytype,
+            ctx: ByIndexContext,
+            comptime removal_type: RemovalType,
+        ) bool {
             const header = self.index_header orelse {
                 // Linear scan.
                 const key_hash = if (store_hash) key_ctx.hash(key) else {};
lib/std/debug.zig
@@ -2838,6 +2838,29 @@ pub fn ConfigurableTrace(comptime size: usize, comptime stack_frame_count: usize
     };
 }
 
+pub const SafetyLock = struct {
+    state: State = .unlocked,
+
+    pub const State = if (runtime_safety) enum { unlocked, locked } else enum { unlocked };
+
+    pub fn lock(l: *SafetyLock) void {
+        if (!runtime_safety) return;
+        assert(l.state == .unlocked);
+        l.state = .locked;
+    }
+
+    pub fn unlock(l: *SafetyLock) void {
+        if (!runtime_safety) return;
+        assert(l.state == .locked);
+        l.state = .unlocked;
+    }
+
+    pub fn assertUnlocked(l: SafetyLock) void {
+        if (!runtime_safety) return;
+        assert(l.state == .unlocked);
+    }
+};
+
 test {
     _ = &dump_hex;
 }
lib/std/hash_map.zig
@@ -416,6 +416,23 @@ pub fn HashMap(
             };
         }
 
+        /// Puts the hash map into a state where any method call that would
+        /// cause an existing key or value pointer to become invalidated will
+        /// instead trigger an assertion.
+        ///
+        /// An additional call to `lockPointers` in such state also triggers an
+        /// assertion.
+        ///
+        /// `unlockPointers` returns the hash map to the previous state.
+        pub fn lockPointers(self: *Self) void {
+            self.unmanaged.lockPointers();
+        }
+
+        /// Undoes a call to `lockPointers`.
+        pub fn unlockPointers(self: *Self) void {
+            self.unmanaged.unlockPointers();
+        }
+
         /// Release the backing array and invalidate this map.
         /// This does *not* deinit keys, values, or the context!
         /// If your keys or values need to be released, ensure
@@ -672,6 +689,7 @@ pub fn HashMap(
         /// Set the map to an empty state, making deinitialization a no-op, and
         /// returning a copy of the original.
         pub fn move(self: *Self) Self {
+            self.unmanaged.pointer_stability.assertUnlocked();
             const result = self.*;
             self.unmanaged = .{};
             return result;
@@ -722,6 +740,9 @@ pub fn HashMapUnmanaged(
         /// `max_load_percentage`.
         available: Size = 0,
 
+        /// Used to detect memory safety violations.
+        pointer_stability: std.debug.SafetyLock = .{},
+
         // This is purely empirical and not a /very smart magic constantâ„¢/.
         /// Capacity of the first grow when bootstrapping the hashmap.
         const minimal_capacity = 8;
@@ -884,11 +905,29 @@ pub fn HashMapUnmanaged(
             };
         }
 
+        /// Puts the hash map into a state where any method call that would
+        /// cause an existing key or value pointer to become invalidated will
+        /// instead trigger an assertion.
+        ///
+        /// An additional call to `lockPointers` in such state also triggers an
+        /// assertion.
+        ///
+        /// `unlockPointers` returns the hash map to the previous state.
+        pub fn lockPointers(self: *Self) void {
+            self.pointer_stability.lock();
+        }
+
+        /// Undoes a call to `lockPointers`.
+        pub fn unlockPointers(self: *Self) void {
+            self.pointer_stability.unlock();
+        }
+
         fn isUnderMaxLoadPercentage(size: Size, cap: Size) bool {
             return size * 100 < max_load_percentage * cap;
         }
 
         pub fn deinit(self: *Self, allocator: Allocator) void {
+            self.pointer_stability.assertUnlocked();
             self.deallocate(allocator);
             self.* = undefined;
         }
@@ -905,6 +944,8 @@ pub fn HashMapUnmanaged(
             return ensureTotalCapacityContext(self, allocator, new_size, undefined);
         }
         pub fn ensureTotalCapacityContext(self: *Self, allocator: Allocator, new_size: Size, ctx: Context) Allocator.Error!void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
             if (new_size > self.size)
                 try self.growIfNeeded(allocator, new_size - self.size, ctx);
         }
@@ -919,14 +960,18 @@ pub fn HashMapUnmanaged(
         }
 
         pub fn clearRetainingCapacity(self: *Self) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
             if (self.metadata) |_| {
                 self.initMetadatas();
                 self.size = 0;
-                self.available = @as(u32, @truncate((self.capacity() * max_load_percentage) / 100));
+                self.available = @truncate((self.capacity() * max_load_percentage) / 100);
             }
         }
 
         pub fn clearAndFree(self: *Self, allocator: Allocator) void {
+            self.pointer_stability.lock();
+            defer self.pointer_stability.unlock();
             self.deallocate(allocator);
             self.size = 0;
             self.available = 0;
@@ -997,9 +1042,11 @@ pub fn HashMapUnmanaged(
             return self.putNoClobberContext(allocator, key, value, undefined);
         }
         pub fn putNoClobberContext(self: *Self, allocator: Allocator, key: K, value: V, ctx: Context) Allocator.Error!void {
-            assert(!self.containsContext(key, ctx));
-            try self.growIfNeeded(allocator, 1, ctx);
-
+            {
+                self.pointer_stability.lock();
+                defer self.pointer_stability.unlock();
+                try self.growIfNeeded(allocator, 1, ctx);
+            }
             self.putAssumeCapacityNoClobberContext(key, value, ctx);
         }
 
@@ -1028,7 +1075,7 @@ pub fn HashMapUnmanaged(
 
             const hash = ctx.hash(key);
             const mask = self.capacity() - 1;
-            var idx = @as(usize, @truncate(hash & mask));
+            var idx: usize = @truncate(hash & mask);
 
             var metadata = self.metadata.? + idx;
             while (metadata[0].isUsed()) {
@@ -1280,17 +1327,21 @@ pub fn HashMapUnmanaged(
             return self.getOrPutContextAdapted(allocator, key, key_ctx, undefined);
         }
         pub fn getOrPutContextAdapted(self: *Self, allocator: Allocator, key: anytype, key_ctx: anytype, ctx: Context) Allocator.Error!GetOrPutResult {
-            self.growIfNeeded(allocator, 1, ctx) catch |err| {
-                // If allocation fails, try to do the lookup anyway.
-                // If we find an existing item, we can return it.
-                // Otherwise return the error, we could not add another.
-                const index = self.getIndex(key, key_ctx) orelse return err;
-                return GetOrPutResult{
-                    .key_ptr = &self.keys()[index],
-                    .value_ptr = &self.values()[index],
-                    .found_existing = true,
+            {
+                self.pointer_stability.lock();
+                defer self.pointer_stability.unlock();
+                self.growIfNeeded(allocator, 1, ctx) catch |err| {
+                    // If allocation fails, try to do the lookup anyway.
+                    // If we find an existing item, we can return it.
+                    // Otherwise return the error, we could not add another.
+                    const index = self.getIndex(key, key_ctx) orelse return err;
+                    return GetOrPutResult{
+                        .key_ptr = &self.keys()[index],
+                        .value_ptr = &self.values()[index],
+                        .found_existing = true,
+                    };
                 };
-            };
+            }
             return self.getOrPutAssumeCapacityAdapted(key, key_ctx);
         }
 
@@ -1495,6 +1546,7 @@ pub fn HashMapUnmanaged(
         /// Set the map to an empty state, making deinitialization a no-op, and
         /// returning a copy of the original.
         pub fn move(self: *Self) Self {
+            self.pointer_stability.assertUnlocked();
             const result = self.*;
             self.* = .{};
             return result;
@@ -1506,28 +1558,28 @@ pub fn HashMapUnmanaged(
             assert(new_cap > self.capacity());
             assert(std.math.isPowerOfTwo(new_cap));
 
-            var map = Self{};
+            var map: Self = .{};
             defer map.deinit(allocator);
+            map.pointer_stability.lock();
             try map.allocate(allocator, new_cap);
             map.initMetadatas();
             map.available = @truncate((new_cap * max_load_percentage) / 100);
 
             if (self.size != 0) {
                 const old_capacity = self.capacity();
-                var i: Size = 0;
-                var metadata = self.metadata.?;
-                const keys_ptr = self.keys();
-                const values_ptr = self.values();
-                while (i < old_capacity) : (i += 1) {
-                    if (metadata[i].isUsed()) {
-                        map.putAssumeCapacityNoClobberContext(keys_ptr[i], values_ptr[i], ctx);
-                        if (map.size == self.size)
-                            break;
-                    }
+                for (
+                    self.metadata.?[0..old_capacity],
+                    self.keys()[0..old_capacity],
+                    self.values()[0..old_capacity],
+                ) |m, k, v| {
+                    if (!m.isUsed()) continue;
+                    map.putAssumeCapacityNoClobberContext(k, v, ctx);
+                    if (map.size == self.size) break;
                 }
             }
 
             self.size = 0;
+            self.pointer_stability = .{ .state = .unlocked };
             std.mem.swap(Self, self, &map);
         }