Commit 6f7a9b3144

Luuk de Gram <luuk@degram.dev>
2023-11-20 20:35:31
wasm-linker: deduplicate aliased functions
When multiple symbols point to the same function, every symbol other than the original is discarded and redirected to the original. This prevents the same function code from being emitted more than once.
1 parent 8856ba7
Changed files (2)
src
src/link/Wasm/Symbol.zig
@@ -100,6 +100,10 @@ pub fn mark(symbol: *Symbol) void {
     symbol.flags |= @intFromEnum(Flag.alive);
 }
 
+pub fn unmark(symbol: *Symbol) void {
+    symbol.flags &= ~@intFromEnum(Flag.alive);
+}
+
 pub fn isAlive(symbol: Symbol) bool {
     return symbol.flags & @intFromEnum(Flag.alive) != 0;
 }
src/link/Wasm.zig
@@ -110,7 +110,7 @@ func_types: std.ArrayListUnmanaged(std.wasm.Type) = .{},
 /// Output function section where the key is the original
 /// function index and the value is function.
 /// This allows us to map multiple symbols to the same function.
-functions: std.AutoArrayHashMapUnmanaged(struct { file: ?u16, index: u32 }, std.wasm.Func) = .{},
+functions: std.AutoArrayHashMapUnmanaged(struct { file: ?u16, index: u32 }, struct { func: std.wasm.Func, sym_index: u32 }) = .{},
 /// Output global section
 wasm_globals: std.ArrayListUnmanaged(std.wasm.Global) = .{},
 /// Memory section
@@ -1584,7 +1584,7 @@ fn getFunctionSignature(wasm: *const Wasm, loc: SymbolLoc) std.wasm.Type {
         const ty_index = wasm.imports.get(loc).?.kind.function;
         return wasm.func_types.items[ty_index];
     }
-    return wasm.func_types.items[wasm.functions.get(.{ .file = loc.file, .index = loc.index }).?.type_index];
+    return wasm.func_types.items[wasm.functions.get(.{ .file = loc.file, .index = symbol.index }).?.func.type_index];
 }
 
 /// Lowers a constant typed value to a local symbol and atom.
@@ -2141,7 +2141,7 @@ fn parseAtom(wasm: *Wasm, atom_index: Atom.Index, kind: Kind) !void {
             try wasm.functions.putNoClobber(
                 wasm.base.allocator,
                 .{ .file = null, .index = index },
-                .{ .type_index = type_index },
+                .{ .func = .{ .type_index = type_index }, .sym_index = atom.sym_index },
             );
             symbol.tag = .function;
             symbol.index = index;
@@ -2274,7 +2274,14 @@ fn allocateAtoms(wasm: *Wasm) !void {
         while (true) {
             const atom = wasm.getAtomPtr(atom_index);
             const symbol_loc = atom.symbolLoc();
-            const sym = symbol_loc.getSymbol(wasm);
+            // Ensure we get the original symbol, so we verify the correct symbol on whether
+            // it is dead or not and ensure an atom is removed when dead.
+            // This is required as we may have parsed aliases into atoms.
+            const sym = if (symbol_loc.file) |object_index| sym: {
+                const object = wasm.objects.items[object_index];
+                break :sym object.symtable[symbol_loc.index];
+            } else wasm.symbols.items[symbol_loc.index];
+
             if (sym.isDead()) {
                 // Dead symbols must be unlinked from the linked-list to prevent them
                 // from being emit into the binary.
@@ -2477,7 +2484,7 @@ fn initializeCallCtorsFunction(wasm: *Wasm) !void {
         // call constructors
         for (wasm.init_funcs.items) |init_func_loc| {
             const symbol = init_func_loc.getSymbol(wasm);
-            const func = wasm.functions.values()[symbol.index - wasm.imported_functions_count];
+            const func = wasm.functions.values()[symbol.index - wasm.imported_functions_count].func;
             const ty = wasm.func_types.items[func.type_index];
 
             // Call function by its function index
@@ -2519,7 +2526,7 @@ fn createSyntheticFunction(
     try wasm.functions.putNoClobber(
         wasm.base.allocator,
         .{ .file = null, .index = func_index },
-        .{ .type_index = ty_index },
+        .{ .func = .{ .type_index = ty_index }, .sym_index = loc.index },
     );
     symbol.index = func_index;
 
@@ -2740,6 +2747,9 @@ fn setupImports(wasm: *Wasm) !void {
 /// Takes the global, function and table section from each linked object file
 /// and merges it into a single section for each.
 fn mergeSections(wasm: *Wasm) !void {
+    var removed_duplicates = std.ArrayList(SymbolLoc).init(wasm.base.allocator);
+    defer removed_duplicates.deinit();
+
     for (wasm.resolved_symbols.keys()) |sym_loc| {
         if (sym_loc.file == null) {
             // Zig code-generated symbols are already within the sections and do not
@@ -2767,9 +2777,19 @@ fn mergeSections(wasm: *Wasm) !void {
                     wasm.base.allocator,
                     .{ .file = sym_loc.file, .index = symbol.index },
                 );
-                if (!gop.found_existing) {
-                    gop.value_ptr.* = object.functions[index];
+                if (gop.found_existing) {
+                    // We found an alias to the same function, discard this symbol in favor of
+                    // the original symbol and point the discard function to it. This ensures
+                    // we only emit a single function, instead of duplicates.
+                    try wasm.discarded.putNoClobber(
+                        wasm.base.allocator,
+                        sym_loc,
+                        .{ .file = gop.key_ptr.*.file, .index = gop.value_ptr.*.sym_index },
+                    );
+                    try removed_duplicates.append(sym_loc);
+                    continue;
                 }
+                gop.value_ptr.* = .{ .func = object.functions[index], .sym_index = sym_loc.index };
                 symbol.index = @as(u32, @intCast(gop.index)) + wasm.imported_functions_count;
             },
             .global => {
@@ -2786,6 +2806,12 @@ fn mergeSections(wasm: *Wasm) !void {
         }
     }
 
+    // For any removed duplicates, remove them from the resolved symbols list
+    for (removed_duplicates.items) |sym_loc| {
+        assert(wasm.resolved_symbols.swapRemove(sym_loc));
+        sym_loc.getSymbol(wasm).unmark();
+    }
+
     log.debug("Merged ({d}) functions", .{wasm.functions.count()});
     log.debug("Merged ({d}) globals", .{wasm.wasm_globals.items.len});
     log.debug("Merged ({d}) tables", .{wasm.tables.items.len});
@@ -2821,7 +2847,7 @@ fn mergeTypes(wasm: *Wasm) !void {
             import.kind.function = try wasm.putOrGetFuncType(original_type);
         } else if (!dirty.contains(symbol.index)) {
             log.debug("Adding type from function '{s}'", .{sym_loc.getName(wasm)});
-            const func = &wasm.functions.values()[symbol.index - wasm.imported_functions_count];
+            const func = &wasm.functions.values()[symbol.index - wasm.imported_functions_count].func;
             func.type_index = try wasm.putOrGetFuncType(object.func_types[func.type_index]);
             dirty.putAssumeCapacityNoClobber(symbol.index, {});
         }
@@ -3498,12 +3524,12 @@ fn linkWithZld(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Node) l
 
     try wasm.markReferences();
     try wasm.setupImports();
+    try wasm.mergeSections();
+    try wasm.mergeTypes();
     try wasm.allocateAtoms();
     try wasm.setupMemory();
     wasm.allocateVirtualAddresses();
     wasm.mapFunctionTable();
-    try wasm.mergeSections();
-    try wasm.mergeTypes();
     try wasm.initializeCallCtorsFunction();
     try wasm.setupInitMemoryFunction();
     try wasm.setupTLSRelocationsFunction();
@@ -3639,12 +3665,12 @@ pub fn flushModule(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Nod
         }
     }
 
+    try wasm.mergeSections();
+    try wasm.mergeTypes();
     try wasm.allocateAtoms();
     try wasm.setupMemory();
     wasm.allocateVirtualAddresses();
     wasm.mapFunctionTable();
-    try wasm.mergeSections();
-    try wasm.mergeTypes();
     try wasm.initializeCallCtorsFunction();
     try wasm.setupInitMemoryFunction();
     try wasm.setupTLSRelocationsFunction();
@@ -3745,7 +3771,7 @@ fn writeToFile(
     if (wasm.functions.count() != 0) {
         const header_offset = try reserveVecSectionHeader(&binary_bytes);
         for (wasm.functions.values()) |function| {
-            try leb.writeULEB128(binary_writer, function.type_index);
+            try leb.writeULEB128(binary_writer, function.func.type_index);
         }
 
         try writeVecSectionHeader(
@@ -3916,6 +3942,7 @@ fn writeToFile(
             sorted_atoms.appendAssumeCapacity(atom); // found more code atoms than functions
             atom_index = atom.prev orelse break;
         }
+        std.debug.assert(wasm.functions.count() == sorted_atoms.items.len);
 
         const atom_sort_fn = struct {
             fn sort(ctx: *const Wasm, lhs: *const Atom, rhs: *const Atom) bool {