Commit 828d23956d

Veikka Tuominen <git@vexu.eu>
2023-12-14 15:03:44
std.heap: add runtime safety for calling `stackFallback(N).get` multiple times
Closes #16344
1 parent 6a32d58
Changed files (2)
deps
lib
deps/aro/aro/Compilation.zig
@@ -1350,9 +1350,10 @@ pub fn hasInclude(
     }
 
     var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa);
+    const sf_allocator = stack_fallback.get();
 
-    while (try it.nextWithFile(filename, stack_fallback.get())) |found| {
-        defer stack_fallback.get().free(found.path);
+    while (try it.nextWithFile(filename, sf_allocator)) |found| {
+        defer sf_allocator.free(found.path);
         if (!std.meta.isError(cwd.access(found.path, .{}))) return true;
     }
     return false;
@@ -1411,9 +1412,10 @@ pub fn findEmbed(
     };
     var it = IncludeDirIterator{ .comp = comp, .cwd_source_id = cwd_source_id };
     var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa);
+    const sf_allocator = stack_fallback.get();
 
-    while (try it.nextWithFile(filename, stack_fallback.get())) |found| {
-        defer stack_fallback.get().free(found.path);
+    while (try it.nextWithFile(filename, sf_allocator)) |found| {
+        defer sf_allocator.free(found.path);
         if (comp.getFileContents(found.path, limit)) |some|
             return some
         else |err| switch (err) {
@@ -1457,8 +1459,10 @@ pub fn findInclude(
     }
 
     var stack_fallback = std.heap.stackFallback(path_buf_stack_limit, comp.gpa);
-    while (try it.nextWithFile(filename, stack_fallback.get())) |found| {
-        defer stack_fallback.get().free(found.path);
+    const sf_allocator = stack_fallback.get();
+
+    while (try it.nextWithFile(filename, sf_allocator)) |found| {
+        defer sf_allocator.free(found.path);
         if (comp.addSourceFromPathExtra(found.path, found.kind)) |some| {
             if (it.tried_ms_cwd) {
                 try comp.addDiagnostic(.{
lib/std/heap.zig
@@ -521,10 +521,16 @@ pub fn StackFallbackAllocator(comptime size: usize) type {
         buffer: [size]u8,
         fallback_allocator: Allocator,
         fixed_buffer_allocator: FixedBufferAllocator,
+        get_called: if (std.debug.runtime_safety) bool else void =
+            if (std.debug.runtime_safety) false else {},
 
         /// This function both fetches a `Allocator` interface to this
         /// allocator *and* resets the internal buffer allocator.
         pub fn get(self: *Self) Allocator {
+            if (std.debug.runtime_safety) {
+                assert(!self.get_called); // `get` called multiple times; instead use `const allocator = stackFallback(N).get();`
+                self.get_called = true;
+            }
             self.fixed_buffer_allocator = FixedBufferAllocator.init(self.buffer[0..]);
             return .{
                 .ptr = self,
@@ -536,6 +542,12 @@ pub fn StackFallbackAllocator(comptime size: usize) type {
             };
         }
 
+        /// Unlike most std allocators `StackFallbackAllocator` modifies
+        /// its internal state before returning an implementation of
+        /// the`Allocator` interface and therefore also doesn't use
+        /// the usual `.allocator()` method.
+        pub const allocator = @compileError("use 'const allocator = stackFallback(N).get();' instead");
+
         fn alloc(
             ctx: *anyopaque,
             len: usize,
@@ -675,13 +687,22 @@ test "FixedBufferAllocator.reset" {
 }
 
 test "StackFallbackAllocator" {
-    const fallback_allocator = page_allocator;
-    var stack_allocator = stackFallback(4096, fallback_allocator);
-
-    try testAllocator(stack_allocator.get());
-    try testAllocatorAligned(stack_allocator.get());
-    try testAllocatorLargeAlignment(stack_allocator.get());
-    try testAllocatorAlignedShrink(stack_allocator.get());
+    {
+        var stack_allocator = stackFallback(4096, std.testing.allocator);
+        try testAllocator(stack_allocator.get());
+    }
+    {
+        var stack_allocator = stackFallback(4096, std.testing.allocator);
+        try testAllocatorAligned(stack_allocator.get());
+    }
+    {
+        var stack_allocator = stackFallback(4096, std.testing.allocator);
+        try testAllocatorLargeAlignment(stack_allocator.get());
+    }
+    {
+        var stack_allocator = stackFallback(4096, std.testing.allocator);
+        try testAllocatorAlignedShrink(stack_allocator.get());
+    }
 }
 
 test "FixedBufferAllocator Reuse memory on realloc" {