Commit fb3345e346

Jakub Konka <kubkon@jakubkonka.com>
2023-03-28 10:40:19
coff: do not use atoms for synthetic import address table
Instead, introduce a custom ImportTable structure which acts much like a thunk does in the MachO linker, and use it to calculate the address of a pointer on-the-fly. Additionally, fix the logic in writeImportTables to allow for multiple DLLs.
1 parent 2a5c4ea
Changed files (2)
src/link/Coff/Relocation.zig
@@ -45,23 +45,25 @@ pcrel: bool,
 length: u2,
 dirty: bool = true,
 
-/// Returns an Atom which is the target node of this relocation edge (if any).
-pub fn getTargetAtomIndex(self: Relocation, coff_file: *const Coff) ?Atom.Index {
+/// Returns the virtual address of this relocation's target, or null if the target cannot be resolved.
+pub fn getTargetAddress(self: Relocation, coff_file: *const Coff) ?u32 {
     switch (self.type) {
-        .got,
-        .got_page,
-        .got_pageoff,
-        => return coff_file.getGotAtomIndexForSymbol(self.target),
-
-        .direct,
-        .page,
-        .pageoff,
-        => return coff_file.getAtomIndexForSymbol(self.target),
-
-        .import,
-        .import_page,
-        .import_pageoff,
-        => return coff_file.getImportAtomIndexForSymbol(self.target),
+        .got, .got_page, .got_pageoff, .direct, .page, .pageoff => {
+            const maybe_target_atom_index = switch (self.type) {
+                .got, .got_page, .got_pageoff => coff_file.getGotAtomIndexForSymbol(self.target),
+                .direct, .page, .pageoff => coff_file.getAtomIndexForSymbol(self.target),
+                else => unreachable,
+            };
+            const target_atom_index = maybe_target_atom_index orelse return null;
+            const target_atom = coff_file.getAtom(target_atom_index);
+            return target_atom.getSymbol(coff_file).value;
+        },
+
+        .import, .import_page, .import_pageoff => {
+            const sym = coff_file.getSymbol(self.target);
+            const itab = coff_file.import_tables.get(sym.value) orelse return null;
+            return itab.getImportAddress(coff_file, self.target);
+        },
     }
 }
 
@@ -73,9 +75,7 @@ pub fn resolve(self: *Relocation, atom_index: Atom.Index, coff_file: *Coff) !voi
 
     const file_offset = source_section.pointer_to_raw_data + source_sym.value - source_section.virtual_address;
 
-    const target_atom_index = self.getTargetAtomIndex(coff_file) orelse return;
-    const target_atom = coff_file.getAtom(target_atom_index);
-    const target_vaddr = target_atom.getSymbol(coff_file).value;
+    const target_vaddr = self.getTargetAddress(coff_file) orelse return;
     const target_vaddr_with_addend = target_vaddr + self.addend;
 
     log.debug("  ({x}: [() => 0x{x} ({s})) ({s}) (in file at 0x{x})", .{
src/link/Coff.zig
@@ -70,9 +70,10 @@ got_entries: std.ArrayListUnmanaged(Entry) = .{},
 got_entries_free_list: std.ArrayListUnmanaged(u32) = .{},
 got_entries_table: std.AutoHashMapUnmanaged(SymbolWithLoc, u32) = .{},
 
-imports: std.ArrayListUnmanaged(Entry) = .{},
-imports_free_list: std.ArrayListUnmanaged(u32) = .{},
-imports_table: std.AutoHashMapUnmanaged(SymbolWithLoc, u32) = .{},
+/// A table of ImportTables partitioned by the library name.
+/// Key is an offset into the interning string table `temp_strtab`.
+import_tables: std.AutoArrayHashMapUnmanaged(u32, ImportTable) = .{},
+imports_count_dirty: bool = true,
 
 /// Virtual address of the entry point procedure relative to image base.
 entry_addr: ?u32 = null,
@@ -159,6 +160,92 @@ const Section = struct {
     free_list: std.ArrayListUnmanaged(Atom.Index) = .{},
 };
 
+/// Represents an import table in the .idata section where each contained pointer
+/// is to a symbol from the same DLL.
+///
+/// The layout of .idata section is as follows:
+///
+/// --- ADDR1 : IAT (all import tables concatenated together)
+///     ptr
+///     ptr
+///     0 sentinel
+///     ptr
+///     0 sentinel
+/// --- ADDR2: headers
+///     ImportDirectoryEntry header
+///     ImportDirectoryEntry header
+///     sentinel
+/// --- ADDR3: lookup tables
+///     Lookup table
+///     0 sentinel
+///     Lookup table
+///     0 sentinel
+/// --- ADDR4: name hint tables
+///     hint-symname
+///     hint-symname
+/// --- ADDR5: DLL names
+///     DLL#1 name
+///     DLL#2 name
+/// --- END
+const ImportTable = struct {
+    entries: std.ArrayListUnmanaged(SymbolWithLoc) = .{},
+    free_list: std.ArrayListUnmanaged(u32) = .{},
+    lookup: std.AutoHashMapUnmanaged(SymbolWithLoc, u32) = .{},
+    index: u8,
+
+    const ITable = @This();
+
+    fn deinit(itab: *ITable, allocator: Allocator) void {
+        itab.entries.deinit(allocator);
+        itab.free_list.deinit(allocator);
+        itab.lookup.deinit(allocator);
+    }
+
+    fn size(itab: ITable) u32 {
+        return @intCast(u32, itab.entries.items.len) * @sizeOf(u64);
+    }
+
+    fn addImport(itab: *ITable, allocator: Allocator, target: SymbolWithLoc) !u32 {
+        try itab.entries.ensureUnusedCapacity(allocator, 1);
+        const index: u32 = blk: {
+            if (itab.free_list.popOrNull()) |index| {
+                log.debug("  (reusing import entry index {d})", .{index});
+                break :blk index;
+            } else {
+                log.debug("  (allocating import entry at index {d})", .{itab.entries.items.len});
+                const index = @intCast(u32, itab.entries.items.len);
+                _ = itab.entries.addOneAssumeCapacity();
+                break :blk index;
+            }
+        };
+        itab.entries.items[index] = target;
+        try itab.lookup.putNoClobber(allocator, target, index);
+        return index;
+    }
+
+    fn getBaseAddress(itab: *const ITable, coff_file: *const Coff) u32 {
+        const header = coff_file.sections.items(.header)[coff_file.idata_section_index.?];
+        var addr = header.virtual_address;
+        for (coff_file.import_tables.values(), 0..) |other_itab, i| {
+            if (itab.index == i) break;
+            addr += @intCast(u32, other_itab.entries.items.len * @sizeOf(u64)) + 8;
+        }
+        return addr;
+    }
+
+    pub fn getImportAddress(itab: *const ITable, coff_file: *const Coff, target: SymbolWithLoc) ?u32 {
+        const index = itab.lookup.get(target) orelse return null;
+        const base_vaddr = itab.getBaseAddress(coff_file);
+        return base_vaddr + index * @sizeOf(u64);
+    }
+
+    pub fn write(itab: ITable, writer: anytype) !void {
+        for (itab.entries.items) |_| {
+            try writer.writeIntLittle(u64, 0);
+        }
+    }
+};
+
 const DeclMetadata = struct {
     atom: Atom.Index,
     section: u16,
@@ -315,9 +402,11 @@ pub fn deinit(self: *Coff) void {
     self.got_entries.deinit(gpa);
     self.got_entries_free_list.deinit(gpa);
     self.got_entries_table.deinit(gpa);
-    self.imports.deinit(gpa);
-    self.imports_free_list.deinit(gpa);
-    self.imports_table.deinit(gpa);
+
+    for (self.import_tables.values()) |*itab| {
+        itab.deinit(gpa);
+    }
+    self.import_tables.deinit(gpa);
 
     {
         var it = self.decls.iterator();
@@ -730,28 +819,6 @@ pub fn allocateGotEntry(self: *Coff, target: SymbolWithLoc) !u32 {
     return index;
 }
 
-pub fn allocateImportEntry(self: *Coff, target: SymbolWithLoc) !u32 {
-    const gpa = self.base.allocator;
-    try self.imports.ensureUnusedCapacity(gpa, 1);
-
-    const index: u32 = blk: {
-        if (self.imports_free_list.popOrNull()) |index| {
-            log.debug("  (reusing import entry index {d})", .{index});
-            break :blk index;
-        } else {
-            log.debug("  (allocating import entry at index {d})", .{self.imports.items.len});
-            const index = @intCast(u32, self.imports.items.len);
-            _ = self.imports.addOneAssumeCapacity();
-            break :blk index;
-        }
-    };
-
-    self.imports.items[index] = .{ .target = target, .sym_index = 0 };
-    try self.imports_table.putNoClobber(gpa, target, index);
-
-    return index;
-}
-
 pub fn createAtom(self: *Coff) !Atom.Index {
     const gpa = self.base.allocator;
     const atom_index = @intCast(Atom.Index, self.atoms.items.len);
@@ -802,21 +869,6 @@ fn createGotAtom(self: *Coff, target: SymbolWithLoc) !Atom.Index {
     return atom_index;
 }
 
-fn createImportAtom(self: *Coff) !Atom.Index {
-    const atom_index = try self.createAtom();
-    const atom = self.getAtomPtr(atom_index);
-    atom.size = @sizeOf(u64);
-    atom.alignment = @alignOf(u64);
-
-    const sym = atom.getSymbolPtr(self);
-    sym.section_number = @intToEnum(coff.SectionNumber, self.idata_section_index.? + 1);
-    sym.value = try self.allocateAtom(atom_index, atom.size, atom.alignment);
-
-    log.debug("allocated import atom at 0x{x}", .{sym.value});
-
-    return atom_index;
-}
-
 fn growAtom(self: *Coff, atom_index: Atom.Index, new_atom_size: u32, alignment: u32) !u32 {
     const atom = self.getAtom(atom_index);
     const sym = atom.getSymbol(self);
@@ -876,10 +928,8 @@ fn markRelocsDirtyByAddress(self: *Coff, addr: u32) void {
     var it = self.relocs.valueIterator();
     while (it.next()) |relocs| {
         for (relocs.items) |*reloc| {
-            const target_atom_index = reloc.getTargetAtomIndex(self) orelse continue;
-            const target_atom = self.getAtom(target_atom_index);
-            const target_sym = target_atom.getSymbol(self);
-            if (target_sym.value < addr) continue;
+            const target_vaddr = reloc.getTargetAddress(self) orelse continue;
+            if (target_vaddr < addr) continue;
             reloc.dirty = true;
         }
     }
@@ -1468,35 +1518,42 @@ pub fn flushModule(self: *Coff, comp: *Compilation, prog_node: *std.Progress.Nod
     sub_prog_node.activate();
     defer sub_prog_node.end();
 
+    const gpa = self.base.allocator;
+
     while (self.unresolved.popOrNull()) |entry| {
         assert(entry.value); // We only expect imports generated by the incremental linker for now.
         const global = self.globals.items[entry.key];
-        if (self.imports_table.contains(global)) continue;
-
-        const import_index = try self.allocateImportEntry(global);
-        const import_atom_index = try self.createImportAtom();
-        const import_atom = self.getAtom(import_atom_index);
-        self.imports.items[import_index].sym_index = import_atom.getSymbolIndex().?;
-        try self.writePtrWidthAtom(import_atom_index);
-    }
-
-    if (build_options.enable_logging) {
-        self.logSymtab();
+        const sym = self.getSymbol(global);
+        const res = try self.import_tables.getOrPut(gpa, sym.value);
+        const itable = res.value_ptr;
+        if (!res.found_existing) {
+            itable.* = .{ .index = @intCast(u8, self.import_tables.values().len - 1) };
+        }
+        if (itable.lookup.contains(global)) continue;
+        // TODO: we could technically write the pointer placeholder for to-be-bound import here,
+        // but since this happens in flush, there is currently no point.
+        _ = try itable.addImport(gpa, global);
+        self.imports_count_dirty = true;
     }
 
+    try self.writeImportTables();
     {
         var it = self.relocs.keyIterator();
         while (it.next()) |atom| {
             try self.resolveRelocs(atom.*);
         }
     }
-    try self.writeImportTable();
     try self.writeBaseRelocations();
 
     if (self.getEntryPoint()) |entry_sym_loc| {
         self.entry_addr = self.getSymbol(entry_sym_loc).value;
     }
 
+    if (build_options.enable_logging) {
+        self.logSymtab();
+        self.logImportTables();
+    }
+
     try self.writeStrtab();
     try self.writeDataDirectoriesHeaders();
     try self.writeSectionHeaders();
@@ -1660,53 +1717,36 @@ fn writeBaseRelocations(self: *Coff) !void {
     };
 }
 
-fn writeImportTable(self: *Coff) !void {
+fn writeImportTables(self: *Coff) !void {
     if (self.idata_section_index == null) return;
+    if (!self.imports_count_dirty) return;
 
     const gpa = self.base.allocator;
 
-    const last_atom_index = self.sections.items(.last_atom_index)[self.idata_section_index.?] orelse return;
+    const ext = ".dll";
     const header = &self.sections.items(.header)[self.idata_section_index.?];
-    const last_atom = self.getAtom(last_atom_index);
-
-    const iat_rva = header.virtual_address;
-    const iat_size = last_atom.getSymbol(self).value + last_atom.size * 2 - iat_rva; // account for sentinel zero pointer
 
-    const dll_name = "KERNEL32.dll";
-
-    var import_dir_entry = coff.ImportDirectoryEntry{
-        .import_lookup_table_rva = @sizeOf(coff.ImportDirectoryEntry) * 2,
-        .time_date_stamp = 0,
-        .forwarder_chain = 0,
-        .name_rva = 0,
-        .import_address_table_rva = iat_rva,
-    };
-
-    // TODO: we currently assume there's only one (implicit) DLL - ntdll
-    var lookup_table = std.ArrayList(coff.ImportLookupEntry64.ByName).init(gpa);
-    defer lookup_table.deinit();
-
-    var names_table = std.ArrayList(u8).init(gpa);
-    defer names_table.deinit();
-
-    // TODO: check if import is still valid
-    for (self.imports.items) |entry| {
-        const target_name = self.getSymbolName(entry.target);
-        const start = names_table.items.len;
-        mem.writeIntLittle(u16, try names_table.addManyAsArray(2), 0); // TODO: currently, hint is set to 0 as we haven't yet parsed any DLL
-        try names_table.appendSlice(target_name);
-        try names_table.append(0);
-        const end = names_table.items.len;
-        if (!mem.isAlignedGeneric(usize, end - start, @sizeOf(u16))) {
-            try names_table.append(0);
+    // Calculate needed size
+    var iat_size: u32 = 0;
+    var dir_table_size: u32 = @sizeOf(coff.ImportDirectoryEntry); // sentinel
+    var lookup_table_size: u32 = 0;
+    var names_table_size: u32 = 0;
+    var dll_names_size: u32 = 0;
+    for (self.import_tables.keys(), 0..) |off, i| {
+        const lib_name = self.temp_strtab.getAssumeExists(off);
+        const itable = self.import_tables.values()[i];
+        iat_size += itable.size() + 8;
+        dir_table_size += @sizeOf(coff.ImportDirectoryEntry);
+        lookup_table_size += @intCast(u32, itable.entries.items.len + 1) * @sizeOf(coff.ImportLookupEntry64.ByName);
+        for (itable.entries.items) |entry| {
+            const sym_name = self.getSymbolName(entry);
+            names_table_size += 2 + mem.alignForwardGeneric(u32, @intCast(u32, sym_name.len + 1), 2);
         }
-        try lookup_table.append(.{ .name_table_rva = @intCast(u31, start) });
+        dll_names_size += @intCast(u32, lib_name.len + ext.len + 1);
     }
-    try lookup_table.append(.{ .name_table_rva = 0 }); // the sentinel
 
-    const dir_entry_size = @sizeOf(coff.ImportDirectoryEntry) + lookup_table.items.len * @sizeOf(coff.ImportLookupEntry64.ByName) + names_table.items.len + dll_name.len + 1;
+    const needed_size = iat_size + dir_table_size + lookup_table_size + names_table_size + dll_names_size;
     const sect_capacity = self.allocatedSize(header.pointer_to_raw_data);
-    const needed_size = @intCast(u32, iat_size + dir_entry_size + @sizeOf(coff.ImportDirectoryEntry));
     if (needed_size > sect_capacity) {
         const new_offset = self.findFreeSpace(needed_size, default_file_alignment);
         log.debug("moving .idata from 0x{x} to 0x{x}", .{ header.pointer_to_raw_data, new_offset });
@@ -1716,41 +1756,105 @@ fn writeImportTable(self: *Coff) !void {
         if (needed_size > sect_vm_capacity) {
             try self.growSectionVM(self.idata_section_index.?, needed_size);
         }
-    }
-
-    // Fixup offsets
-    const base_rva = iat_rva + iat_size;
-    import_dir_entry.import_lookup_table_rva += base_rva;
-    import_dir_entry.name_rva = @intCast(u32, base_rva + dir_entry_size + @sizeOf(coff.ImportDirectoryEntry) - dll_name.len - 1);
 
-    for (lookup_table.items[0 .. lookup_table.items.len - 1]) |*lk| {
-        lk.name_table_rva += @intCast(u31, base_rva + @sizeOf(coff.ImportDirectoryEntry) * 2 + lookup_table.items.len * @sizeOf(coff.ImportLookupEntry64.ByName));
+        header.virtual_size = @max(header.virtual_size, needed_size);
+        header.size_of_raw_data = needed_size;
     }
 
+    // Do the actual writes
     var buffer = std.ArrayList(u8).init(gpa);
     defer buffer.deinit();
-    try buffer.ensureTotalCapacity(dir_entry_size + @sizeOf(coff.ImportDirectoryEntry));
-    buffer.appendSliceAssumeCapacity(mem.asBytes(&import_dir_entry));
-    buffer.appendNTimesAssumeCapacity(0, @sizeOf(coff.ImportDirectoryEntry)); // the sentinel; TODO: I think doing all of the above on bytes directly might be cleaner
-    buffer.appendSliceAssumeCapacity(mem.sliceAsBytes(lookup_table.items));
-    buffer.appendSliceAssumeCapacity(names_table.items);
-    buffer.appendSliceAssumeCapacity(dll_name);
-    buffer.appendAssumeCapacity(0);
-
-    try self.base.file.?.pwriteAll(buffer.items, header.pointer_to_raw_data + iat_size);
-    // Override the IAT atoms
-    // TODO: we should rewrite only dirtied atoms, but that's for way later
-    try self.base.file.?.pwriteAll(mem.sliceAsBytes(lookup_table.items), header.pointer_to_raw_data);
+    try buffer.ensureTotalCapacityPrecise(needed_size);
+    buffer.resize(needed_size) catch unreachable;
+
+    const dir_header_size = @sizeOf(coff.ImportDirectoryEntry);
+    const lookup_entry_size = @sizeOf(coff.ImportLookupEntry64.ByName);
+
+    var iat_offset: u32 = 0;
+    var dir_table_offset = iat_size;
+    var lookup_table_offset = dir_table_offset + dir_table_size;
+    var names_table_offset = lookup_table_offset + lookup_table_size;
+    var dll_names_offset = names_table_offset + names_table_size;
+    for (self.import_tables.keys(), 0..) |off, i| {
+        const lib_name = self.temp_strtab.getAssumeExists(off);
+        const itable = self.import_tables.values()[i];
+
+        // Import directory entry (header) for this DLL
+        const lookup_header = coff.ImportDirectoryEntry{
+            .import_lookup_table_rva = header.virtual_address + lookup_table_offset,
+            .time_date_stamp = 0,
+            .forwarder_chain = 0,
+            .name_rva = header.virtual_address + dll_names_offset,
+            .import_address_table_rva = header.virtual_address + iat_offset,
+        };
+        mem.copy(u8, buffer.items[dir_table_offset..], mem.asBytes(&lookup_header));
+        dir_table_offset += dir_header_size;
+
+        for (itable.entries.items) |entry| {
+            const import_name = self.getSymbolName(entry);
+
+            // IAT and lookup table entry
+            const lookup = coff.ImportLookupEntry64.ByName{ .name_table_rva = @intCast(u31, header.virtual_address + names_table_offset) };
+            mem.copy(u8, buffer.items[iat_offset..], mem.asBytes(&lookup));
+            iat_offset += lookup_entry_size;
+            mem.copy(u8, buffer.items[lookup_table_offset..], mem.asBytes(&lookup));
+            lookup_table_offset += lookup_entry_size;
+
+            // Names table entry
+            mem.writeIntLittle(u16, buffer.items[names_table_offset..][0..2], 0); // Hint set to 0 until we learn how to parse DLLs
+            names_table_offset += 2;
+            mem.copy(u8, buffer.items[names_table_offset..], import_name);
+            names_table_offset += @intCast(u32, import_name.len);
+            buffer.items[names_table_offset] = 0;
+            names_table_offset += 1;
+            if (!mem.isAlignedGeneric(usize, names_table_offset, @sizeOf(u16))) {
+                buffer.items[names_table_offset] = 0;
+                names_table_offset += 1;
+            }
+        }
 
-    self.data_directories[@enumToInt(coff.DirectoryEntry.IMPORT)] = .{
-        .virtual_address = iat_rva + iat_size,
-        .size = @intCast(u32, @sizeOf(coff.ImportDirectoryEntry) * 2),
+        // IAT sentinel
+        mem.writeIntLittle(u64, buffer.items[iat_offset..][0..lookup_entry_size], 0);
+        iat_offset += 8;
+
+        // Lookup table sentinel
+        mem.copy(u8, buffer.items[lookup_table_offset..], mem.asBytes(&coff.ImportLookupEntry64.ByName{ .name_table_rva = 0 }));
+        lookup_table_offset += lookup_entry_size;
+
+        // DLL name
+        mem.copy(u8, buffer.items[dll_names_offset..], lib_name);
+        dll_names_offset += @intCast(u32, lib_name.len);
+        mem.copy(u8, buffer.items[dll_names_offset..], ext);
+        dll_names_offset += @intCast(u32, ext.len);
+        buffer.items[dll_names_offset] = 0;
+        dll_names_offset += 1;
+    }
+
+    // Import directory table sentinel (all-zero entry)
+    const lookup_header = coff.ImportDirectoryEntry{
+        .import_lookup_table_rva = 0,
+        .time_date_stamp = 0,
+        .forwarder_chain = 0,
+        .name_rva = 0,
+        .import_address_table_rva = 0,
     };
+    mem.copy(u8, buffer.items[dir_table_offset..], mem.asBytes(&lookup_header));
+    dir_table_offset += dir_header_size;
+
+    assert(dll_names_offset == needed_size);
 
+    try self.base.file.?.pwriteAll(buffer.items, header.pointer_to_raw_data);
+
+    self.data_directories[@enumToInt(coff.DirectoryEntry.IMPORT)] = .{
+        .virtual_address = header.virtual_address + iat_size,
+        .size = dir_table_size,
+    };
     self.data_directories[@enumToInt(coff.DirectoryEntry.IAT)] = .{
-        .virtual_address = iat_rva,
+        .virtual_address = header.virtual_address,
         .size = iat_size,
     };
+
+    self.imports_count_dirty = false;
 }
 
 fn writeStrtab(self: *Coff) !void {
@@ -2139,14 +2243,6 @@ pub fn getGotAtomIndexForSymbol(self: *const Coff, sym_loc: SymbolWithLoc) ?Atom
     return self.getAtomIndexForSymbol(.{ .sym_index = got_entry.sym_index, .file = null });
 }
 
-/// Returns import atom that references `sym_loc` if one exists.
-/// Returns null otherwise.
-pub fn getImportAtomIndexForSymbol(self: *const Coff, sym_loc: SymbolWithLoc) ?Atom.Index {
-    const imports_index = self.imports_table.get(sym_loc) orelse return null;
-    const imports_entry = self.imports.items[imports_index];
-    return self.getAtomIndexForSymbol(.{ .sym_index = imports_entry.sym_index, .file = null });
-}
-
 fn setSectionName(self: *Coff, header: *coff.SectionHeader, name: []const u8) !void {
     if (name.len <= 8) {
         mem.copy(u8, &header.name, name);
@@ -2267,3 +2363,19 @@ fn logSections(self: *Coff) void {
         });
     }
 }
+
+fn logImportTables(self: *const Coff) void {
+    log.debug("import tables:", .{});
+    for (self.import_tables.keys(), 0..) |off, i| {
+        const lib_name = self.temp_strtab.getAssumeExists(off);
+        const itable = self.import_tables.values()[i];
+        log.debug("IAT({s}) @{x}:", .{ lib_name, itable.getBaseAddress(self) });
+        for (itable.entries.items, 0..) |entry, j| {
+            log.debug("  {d}@{?x} => {s}", .{
+                j,
+                itable.getImportAddress(self, entry),
+                self.getSymbolName(entry),
+            });
+        }
+    }
+}