Commit 216badef0b

Jakub Konka <kubkon@jakubkonka.com>
2023-03-30 18:52:56
coff: use std.os.windows wrappers; fix relocating in-file
1 parent ba5302c
Changed files (2)
src/link/Coff/Relocation.zig
@@ -74,7 +74,7 @@ pub fn getTargetAddress(self: Relocation, coff_file: *const Coff) ?u32 {
 
 /// Returns `false` if obtaining the target address has been deferred until `flushModule`.
 /// This can happen when trying to resolve address of an import table entry ahead of time.
-pub fn resolve(self: Relocation, atom_index: Atom.Index, code: []u8, coff_file: *Coff) bool {
+pub fn resolve(self: Relocation, atom_index: Atom.Index, code: []u8, image_base: u64, coff_file: *Coff) bool {
     const atom = coff_file.getAtom(atom_index);
     const source_sym = atom.getSymbol(coff_file);
     const source_vaddr = source_sym.value + self.offset;
@@ -92,7 +92,7 @@ pub fn resolve(self: Relocation, atom_index: Atom.Index, code: []u8, coff_file:
     const ctx: Context = .{
         .source_vaddr = source_vaddr,
         .target_vaddr = target_vaddr_with_addend,
-        .image_base = coff_file.hot_state.loaded_base_address orelse coff_file.getImageBase(),
+        .image_base = image_base,
         .code = code,
         .ptr_width = coff_file.ptr_width,
     };
src/link/Coff.zig
@@ -93,7 +93,9 @@ base_relocs: BaseRelocationTable = .{},
 hot_state: HotUpdateState = .{},
 
 const HotUpdateState = struct {
-    loaded_base_address: ?u64 = null,
+    /// Base address at which the process (image) got loaded.
+    /// We need this info to correctly slide pointers when relocating.
+    loaded_base_address: ?std.os.windows.HMODULE = null,
 };
 
 const Entry = struct {
@@ -784,139 +786,53 @@ fn writeAtom(self: *Coff, atom_index: Atom.Index, code: []u8) !void {
         file_offset,
         file_offset + code.len,
     });
-    self.resolveRelocs(atom_index, code);
 
     if (self.base.child_pid) |handle| {
-        const vaddr = sym.value + (self.hot_state.loaded_base_address orelse self.getImageBase());
+        const slide = @ptrToInt(self.hot_state.loaded_base_address.?);
+
+        const mem_code = try self.base.allocator.dupe(u8, code);
+        defer self.base.allocator.free(mem_code);
+        self.resolveRelocs(atom_index, mem_code, slide);
+
+        const vaddr = sym.value + slide;
+        const pvaddr = @intToPtr(*anyopaque, vaddr);
         log.debug("writing to memory at address {x}", .{vaddr});
         if (section.header.flags.MEM_WRITE == 0) {
             log.debug("page not mapped for write access; re-mapping...", .{});
-            try writeMemProtected(handle, vaddr, code);
+            writeMemProtected(handle, pvaddr, mem_code) catch |err| {
+                log.warn("writing to protected memory failed with error: {s}", .{@errorName(err)});
+            };
         } else {
-            if (WriteProcessMemory(handle, vaddr, code)) |amt| {
-                if (amt != code.len) return error.InputOutput;
-            } else |err| {
-                log.warn("writing to process memory failed with error: {s}", .{@errorName(err)});
-            }
+            writeMem(handle, pvaddr, mem_code) catch |err| {
+                log.warn("writing to protected memory failed with error: {s}", .{@errorName(err)});
+            };
         }
     }
 
+    self.resolveRelocs(atom_index, code, self.getImageBase());
     try self.base.file.?.pwriteAll(code, file_offset);
 }
 
-extern "ntdll" fn NtReadVirtualMemory(
-    ProcessHandle: std.os.windows.HANDLE,
-    BaseAddress: std.os.windows.PVOID,
-    Buffer: std.os.windows.LPVOID,
-    NumberOfBytesToRead: std.os.windows.SIZE_T,
-    NumberOfBytesRead: ?*std.os.windows.SIZE_T,
-) std.os.windows.NTSTATUS;
-
-extern "ntdll" fn NtWriteVirtualMemory(
-    ProcessHandle: std.os.windows.HANDLE,
-    BaseAddress: std.os.windows.PVOID,
-    Buffer: std.os.windows.LPCVOID,
-    NumberOfBytesToWrite: std.os.windows.SIZE_T,
-    NumberOfBytesWritten: ?*std.os.windows.SIZE_T,
-) std.os.windows.NTSTATUS;
-
-extern "ntdll" fn NtProtectVirtualMemory(
-    ProcessHandle: std.os.windows.HANDLE,
-    BaseAddress: *std.os.windows.PVOID,
-    NumberOfBytesToProtect: *std.os.windows.SIZE_T,
-    NewAccessProtection: std.os.windows.ULONG,
-    OldAccessProtection: *std.os.windows.ULONG,
-) std.os.windows.NTSTATUS;
-
-fn ReadProcessMemory(handle: std.os.windows.HANDLE, base_addr: usize, buffer: []u8) ![]u8 {
-    var nread: usize = 0;
-    switch (NtReadVirtualMemory(
-        handle,
-        @intToPtr(*anyopaque, base_addr),
-        buffer.ptr,
-        buffer.len,
-        &nread,
-    )) {
-        .SUCCESS => return buffer[0..nread],
-        else => |rc| return std.os.windows.unexpectedStatus(rc),
-    }
-}
-
-fn WriteProcessMemory(handle: std.os.windows.HANDLE, base_addr: usize, buffer: []const u8) !usize {
-    var nwritten: usize = 0;
-    switch (NtWriteVirtualMemory(
-        handle,
-        @intToPtr(*anyopaque, base_addr),
-        @ptrCast(*const anyopaque, buffer.ptr),
-        buffer.len,
-        &nwritten,
-    )) {
-        .SUCCESS => return nwritten,
-        else => |rc| return std.os.windows.unexpectedStatus(rc),
-    }
-}
-
-fn VirtualProtectEx(handle: std.os.windows.HANDLE, base_addr: usize, size: usize, new_prot: u32) !u32 {
-    var out_paddr = @intToPtr(*anyopaque, base_addr);
-    var out_size = size;
-    var old_prot: u32 = undefined;
-    switch (NtProtectVirtualMemory(
-        handle,
-        &out_paddr,
-        &out_size,
-        new_prot,
-        &old_prot,
-    )) {
-        .SUCCESS => return old_prot,
-        else => |rc| return std.os.windows.unexpectedStatus(rc),
-    }
-}
-
-const PROCESS_BASIC_INFORMATION = extern struct {
-    ExitStatus: std.os.windows.NTSTATUS,
-    PebBaseAddress: *std.os.windows.PEB,
-    AffinityMask: std.os.windows.ULONG_PTR,
-    BasePriority: std.os.windows.KPRIORITY,
-    UniqueProcessId: std.os.windows.ULONG_PTR,
-    InheritedFromUniqueProcessId: std.os.windows.ULONG_PTR,
-};
-
-fn getProcessBaseAddress(handle: std.ChildProcess.Id) !u64 {
-    var info: PROCESS_BASIC_INFORMATION = undefined;
-    var nread: std.os.windows.DWORD = 0;
-    const rc = std.os.windows.ntdll.NtQueryInformationProcess(
-        handle,
-        .ProcessBasicInformation,
-        &info,
-        @sizeOf(PROCESS_BASIC_INFORMATION),
-        &nread,
-    );
-    switch (rc) {
-        .SUCCESS => {},
-        else => return std.os.windows.unexpectedStatus(rc),
-    }
-
-    var peb_buf: [@sizeOf(std.os.windows.PEB)]u8 align(@alignOf(std.os.windows.PEB)) = undefined;
-    const pebout = try ReadProcessMemory(handle, @ptrToInt(info.PebBaseAddress), &peb_buf);
-    const peb = @ptrCast(*const std.os.windows.PEB, @alignCast(@alignOf(std.os.windows.PEB), pebout.ptr));
-    return @ptrToInt(peb.ImageBaseAddress);
-}
-
-fn debugMem(allocator: Allocator, handle: std.ChildProcess.Id, vaddr: u64, code: []const u8) !void {
+fn debugMem(allocator: Allocator, handle: std.ChildProcess.Id, pvaddr: std.os.windows.LPVOID, code: []const u8) !void {
     var buffer = try allocator.alloc(u8, code.len);
     defer allocator.free(buffer);
-    const memread = try ReadProcessMemory(handle, vaddr, buffer);
+    const memread = try std.os.windows.ReadProcessMemory(handle, pvaddr, buffer);
     log.debug("in memory: {x}", .{std.fmt.fmtSliceHexLower(memread)});
     log.debug("to write: {x}", .{std.fmt.fmtSliceHexLower(code)});
 }
 
-fn writeMemProtected(handle: std.ChildProcess.Id, vaddr: u64, code: []const u8) !void {
-    const old_prot = try VirtualProtectEx(handle, vaddr, code.len, std.os.windows.PAGE_EXECUTE_WRITECOPY);
-    const amt = try WriteProcessMemory(handle, vaddr, code);
-    if (amt != code.len) return error.InputOutput;
+fn writeMemProtected(handle: std.ChildProcess.Id, pvaddr: std.os.windows.LPVOID, code: []const u8) !void {
+    var old_prot: std.os.windows.DWORD = undefined;
+    try std.os.windows.VirtualProtectEx(handle, pvaddr, code.len, std.os.windows.PAGE_EXECUTE_WRITECOPY, &old_prot);
+    try writeMem(handle, pvaddr, code);
     // TODO: We can probably just set the pages writeable and leave it at that without having to restore the attributes.
     // For that though, we want to track which page has already been modified.
-    _ = try VirtualProtectEx(handle, vaddr, code.len, old_prot);
+    try std.os.windows.VirtualProtectEx(handle, pvaddr, code.len, old_prot, null);
+}
+
+fn writeMem(handle: std.ChildProcess.Id, pvaddr: std.os.windows.LPVOID, code: []const u8) !void {
+    const amt = try std.os.windows.WriteProcessMemory(handle, pvaddr, code);
+    if (amt != code.len) return error.InputOutput;
 }
 
 fn writePtrWidthAtom(self: *Coff, atom_index: Atom.Index) !void {
@@ -952,14 +868,14 @@ fn markRelocsDirtyByAddress(self: *Coff, addr: u32) void {
     }
 }
 
-fn resolveRelocs(self: *Coff, atom_index: Atom.Index, code: []u8) void {
+fn resolveRelocs(self: *Coff, atom_index: Atom.Index, code: []u8, image_base: u64) void {
     const relocs = self.relocs.getPtr(atom_index) orelse return;
 
     log.debug("relocating '{s}'", .{self.getAtom(atom_index).getName(self)});
 
     for (relocs.items) |*reloc| {
         if (!reloc.dirty) continue;
-        if (reloc.resolve(atom_index, code, self)) {
+        if (reloc.resolve(atom_index, code, image_base, self)) {
             reloc.dirty = false;
         }
     }
@@ -967,7 +883,7 @@ fn resolveRelocs(self: *Coff, atom_index: Atom.Index, code: []u8) void {
 
 pub fn ptraceAttach(self: *Coff, handle: std.ChildProcess.Id) !void {
     log.debug("attaching to process with handle {*}", .{handle});
-    self.hot_state.loaded_base_address = getProcessBaseAddress(handle) catch |err| {
+    self.hot_state.loaded_base_address = std.os.windows.ProcessBaseAddress(handle) catch |err| {
         log.warn("failed to get base address for the process with error: {s}", .{@errorName(err)});
         return;
     };