Commit d3599ec73c

mlugg <mlugg@mlugg.co.uk>
2022-12-22 14:41:05
std: implement os.mprotect on Windows
1 parent a9eb463
Changed files (4)
lib/std/c/windows.zig
@@ -94,6 +94,18 @@ pub const SEEK = struct {
     pub const END = 2;
 };
 
+/// Basic memory protection flags
+pub const PROT = struct {
+    /// page can not be accessed
+    pub const NONE = 0x0;
+    /// page can be read
+    pub const READ = 0x1;
+    /// page can be written
+    pub const WRITE = 0x2;
+    /// page can be executed
+    pub const EXEC = 0x4;
+};
+
 pub const E = enum(u16) {
     /// No error occurred.
     SUCCESS = 0,
lib/std/os/windows/ntdll.zig
@@ -291,3 +291,11 @@ pub extern "ntdll" fn RtlQueryRegistryValues(
     Context: ?*anyopaque,
     Environment: ?*anyopaque,
 ) callconv(WINAPI) NTSTATUS;
+
+pub extern "ntdll" fn NtProtectVirtualMemory(
+    ProcessHandle: HANDLE,
+    BaseAddress: *PVOID,
+    NumberOfBytesToProtect: *ULONG,
+    NewAccessProtection: ULONG,
+    OldAccessProtection: *ULONG,
+) callconv(WINAPI) NTSTATUS;
lib/std/os/windows.zig
@@ -1499,6 +1499,22 @@ pub fn VirtualFree(lpAddress: ?LPVOID, dwSize: usize, dwFreeType: DWORD) void {
     assert(kernel32.VirtualFree(lpAddress, dwSize, dwFreeType) != 0);
 }
 
+pub const VirtualProtectError = error{
+    InvalidAddress,
+    Unexpected,
+};
+
+pub fn VirtualProtect(lpAddress: ?LPVOID, dwSize: SIZE_T, flNewProtect: DWORD, lpflOldProtect: *DWORD) VirtualProtectError!void {
+    // ntdll takes an extra level of indirection here
+    var addr = lpAddress;
+    var size = dwSize;
+    switch (ntdll.NtProtectVirtualMemory(self_process_handle, &addr, &size, flNewProtect, lpflOldProtect)) {
+        .SUCCESS => {},
+        .INVALID_ADDRESS => return error.InvalidAddress,
+        else => |st| return unexpectedStatus(st),
+    }
+}
+
 pub const VirtualQueryError = error{Unexpected};
 
 pub fn VirtualQuery(lpAddress: ?LPVOID, lpBuffer: PMEMORY_BASIC_INFORMATION, dwLength: SIZE_T) VirtualQueryError!SIZE_T {
lib/std/os.zig
@@ -4226,12 +4226,30 @@ pub const MProtectError = error{
 /// `memory.len` must be page-aligned.
 pub fn mprotect(memory: []align(mem.page_size) u8, protection: u32) MProtectError!void {
     assert(mem.isAligned(memory.len, mem.page_size));
-    switch (errno(system.mprotect(memory.ptr, memory.len, protection))) {
-        .SUCCESS => return,
-        .INVAL => unreachable,
-        .ACCES => return error.AccessDenied,
-        .NOMEM => return error.OutOfMemory,
-        else => |err| return unexpectedErrno(err),
+    if (builtin.os.tag == .windows) {
+        const win_prot: windows.DWORD = switch (@truncate(u3, protection)) {
+            0b000 => windows.PAGE_NOACCESS,
+            0b001 => windows.PAGE_READONLY,
+            0b010 => unreachable, // +w -r not allowed
+            0b011 => windows.PAGE_READWRITE,
+            0b100 => windows.PAGE_EXECUTE,
+            0b101 => windows.PAGE_EXECUTE_READ,
+            0b110 => unreachable, // +w -r not allowed
+            0b111 => windows.PAGE_EXECUTE_READWRITE,
+        };
+        var old: windows.DWORD = undefined;
+        windows.VirtualProtect(memory.ptr, memory.len, win_prot, &old) catch |err| switch (err) {
+            error.InvalidAddress => return error.AccessDenied,
+            error.Unexpected => return error.Unexpected,
+        };
+    } else {
+        switch (errno(system.mprotect(memory.ptr, memory.len, protection))) {
+            .SUCCESS => return,
+            .INVAL => unreachable,
+            .ACCES => return error.AccessDenied,
+            .NOMEM => return error.OutOfMemory,
+            else => |err| return unexpectedErrno(err),
+        }
     }
 }