Commit 16875b3598

ziggoon <zackmusgrave20@gmail.com>
2025-03-05 05:01:08
Update the Windows implementation of std.heap.PageAllocator to remove a race condition and to use NtAllocateVirtualMemory / NtFreeVirtualMemory instead of VirtualAlloc / VirtualFree
1 parent 79460d4
Changed files (1)
lib
lib/std/heap/PageAllocator.zig
@@ -6,9 +6,14 @@ const maxInt = std.math.maxInt;
 const assert = std.debug.assert;
 const native_os = builtin.os.tag;
 const windows = std.os.windows;
+const ntdll = windows.ntdll;
 const posix = std.posix;
 const page_size_min = std.heap.page_size_min;
 
+const SUCCESS = @import("../os/windows/ntstatus.zig").NTSTATUS.SUCCESS;
+const MEM_RESERVE_PLACEHOLDER = windows.MEM_RESERVE_PLACEHOLDER;
+const MEM_PRESERVE_PLACEHOLDER = windows.MEM_PRESERVE_PLACEHOLDER;
+
 pub const vtable: Allocator.VTable = .{
     .alloc = alloc,
     .resize = resize,
@@ -22,51 +27,62 @@ pub fn map(n: usize, alignment: mem.Alignment) ?[*]u8 {
     const alignment_bytes = alignment.toByteUnits();
 
     if (native_os == .windows) {
-        // According to official documentation, VirtualAlloc aligns to page
-        // boundary, however, empirically it reserves pages on a 64K boundary.
-        // Since it is very likely the requested alignment will be honored,
-        // this logic first tries a call with exactly the size requested,
-        // before falling back to the loop below.
-        // https://devblogs.microsoft.com/oldnewthing/?p=42223
-        const addr = windows.VirtualAlloc(
-            null,
-            // VirtualAlloc will round the length to a multiple of page size.
-            // "If the lpAddress parameter is NULL, this value is rounded up to
-            // the next page boundary".
-            n,
-            windows.MEM_COMMIT | windows.MEM_RESERVE,
-            windows.PAGE_READWRITE,
-        ) catch return null;
-
-        if (mem.isAligned(@intFromPtr(addr), alignment_bytes))
-            return @ptrCast(addr);
-
-        // Fallback: reserve a range of memory large enough to find a
-        // sufficiently aligned address, then free the entire range and
-        // immediately allocate the desired subset. Another thread may have won
-        // the race to map the target range, in which case a retry is needed.
-        windows.VirtualFree(addr, 0, windows.MEM_RELEASE);
+        var base_addr: ?*anyopaque = null;
+        var size: windows.SIZE_T = n;
+
+        var status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | windows.MEM_RESERVE, windows.PAGE_READWRITE);
+
+        if (status == SUCCESS and mem.isAligned(@intFromPtr(base_addr), alignment_bytes)) {
+            return @ptrCast(base_addr);
+        }
+
+        if (status == SUCCESS) {
+            var region_size: windows.SIZE_T = 0;
+            _ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
+        }
 
         const overalloc_len = n + alignment_bytes - page_size;
         const aligned_len = mem.alignForward(usize, n, page_size);
 
-        while (true) {
-            const reserved_addr = windows.VirtualAlloc(
-                null,
-                overalloc_len,
-                windows.MEM_RESERVE,
-                windows.PAGE_NOACCESS,
-            ) catch return null;
-            const aligned_addr = mem.alignForward(usize, @intFromPtr(reserved_addr), alignment_bytes);
-            windows.VirtualFree(reserved_addr, 0, windows.MEM_RELEASE);
-            const ptr = windows.VirtualAlloc(
-                @ptrFromInt(aligned_addr),
-                aligned_len,
-                windows.MEM_COMMIT | windows.MEM_RESERVE,
-                windows.PAGE_READWRITE,
-            ) catch continue;
-            return @ptrCast(ptr);
+        base_addr = null;
+        size = overalloc_len;
+
+        status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_RESERVE | MEM_RESERVE_PLACEHOLDER, windows.PAGE_NOACCESS);
+
+        if (status != SUCCESS) return null;
+
+        const placeholder_addr = @intFromPtr(base_addr);
+        const aligned_addr = mem.alignForward(usize, placeholder_addr, alignment_bytes);
+        const prefix_size = aligned_addr - placeholder_addr;
+
+        if (prefix_size > 0) {
+            var prefix_base = base_addr;
+            var prefix_size_param: windows.SIZE_T = prefix_size;
+            _ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&prefix_base), &prefix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
         }
+
+        const suffix_start = aligned_addr + aligned_len;
+        const suffix_size = (placeholder_addr + overalloc_len) - suffix_start;
+        if (suffix_size > 0) {
+            var suffix_base = @as(?*anyopaque, @ptrFromInt(suffix_start));
+            var suffix_size_param: windows.SIZE_T = suffix_size;
+            _ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&suffix_base), &suffix_size_param, windows.MEM_RELEASE | MEM_PRESERVE_PLACEHOLDER);
+        }
+
+        base_addr = @ptrFromInt(aligned_addr);
+        size = aligned_len;
+
+        status = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), 0, &size, windows.MEM_COMMIT | MEM_PRESERVE_PLACEHOLDER, windows.PAGE_READWRITE);
+
+        if (status == SUCCESS) {
+            return @ptrCast(base_addr);
+        }
+
+        base_addr = @as(?*anyopaque, @ptrFromInt(aligned_addr));
+        size = aligned_len;
+        _ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &size, windows.MEM_RELEASE);
+
+        return null;
     }
 
     const aligned_len = mem.alignForward(usize, n, page_size);
@@ -104,26 +120,14 @@ fn alloc(context: *anyopaque, n: usize, alignment: mem.Alignment, ra: usize) ?[*
     return map(n, alignment);
 }
 
-fn resize(
-    context: *anyopaque,
-    memory: []u8,
-    alignment: mem.Alignment,
-    new_len: usize,
-    return_address: usize,
-) bool {
+fn resize(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) bool {
     _ = context;
     _ = alignment;
     _ = return_address;
     return realloc(memory, new_len, false) != null;
 }
 
-fn remap(
-    context: *anyopaque,
-    memory: []u8,
-    alignment: mem.Alignment,
-    new_len: usize,
-    return_address: usize,
-) ?[*]u8 {
+fn remap(context: *anyopaque, memory: []u8, alignment: mem.Alignment, new_len: usize, return_address: usize) ?[*]u8 {
     _ = context;
     _ = alignment;
     _ = return_address;
@@ -139,7 +143,9 @@ fn free(context: *anyopaque, memory: []u8, alignment: mem.Alignment, return_addr
 
 pub fn unmap(memory: []align(page_size_min) u8) void {
     if (native_os == .windows) {
-        windows.VirtualFree(memory.ptr, 0, windows.MEM_RELEASE);
+        var base_addr: ?*anyopaque = memory.ptr;
+        var region_size: windows.SIZE_T = 0;
+        _ = ntdll.NtFreeVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&base_addr), &region_size, windows.MEM_RELEASE);
     } else {
         const page_aligned_len = mem.alignForward(usize, memory.len, std.heap.pageSize());
         posix.munmap(memory.ptr[0..page_aligned_len]);
@@ -157,13 +163,10 @@ pub fn realloc(uncasted_memory: []u8, new_len: usize, may_move: bool) ?[*]u8 {
             const old_addr_end = base_addr + memory.len;
             const new_addr_end = mem.alignForward(usize, base_addr + new_len, page_size);
             if (old_addr_end > new_addr_end) {
-                // For shrinking that is not releasing, we will only decommit
-                // the pages not needed anymore.
-                windows.VirtualFree(
-                    @ptrFromInt(new_addr_end),
-                    old_addr_end - new_addr_end,
-                    windows.MEM_DECOMMIT,
-                );
+                var decommit_addr: ?*anyopaque = @ptrFromInt(new_addr_end);
+                var decommit_size: windows.SIZE_T = old_addr_end - new_addr_end;
+
+                _ = ntdll.NtAllocateVirtualMemory(windows.GetCurrentProcess(), @ptrCast(&decommit_addr), 0, &decommit_size, windows.MEM_RESET, windows.PAGE_NOACCESS);
             }
             return memory.ptr;
         }