Commit f8d1efd99a

Luuk de Gram <luuk@degram.dev>
2023-01-11 07:03:15
wasm-linker: implement __wasm_call_ctors symbol
This implements the `__wasm_call_ctors` symbol. This symbol is automatically referenced by libc to initialize its constructors. We first retrieve all constructors from each object file, and then create a function body that calls each constructor based on its priority. Constructors are not allowed to have any parameters, but are allowed to have a return type. When a return type does exist, we simply drop its value from the stack after calling the constructor to ensure we pass the stack validator.
1 parent 1072f82
Changed files (2)
src/link/Wasm.zig
@@ -118,6 +118,9 @@ memories: std.wasm.Memory = .{ .limits = .{ .min = 0, .max = null } },
 tables: std.ArrayListUnmanaged(std.wasm.Table) = .{},
 /// Output export section
 exports: std.ArrayListUnmanaged(types.Export) = .{},
+/// List of initialization functions. These must be called in order of priority
+/// by the (synthetic) __wasm_call_ctors function.
+init_funcs: std.ArrayListUnmanaged(InitFuncLoc) = .{},
 
 /// Indirect function table, used to call function pointers
 /// When this is non-zero, we must emit a table entry,
@@ -238,6 +241,34 @@ pub const SymbolLoc = struct {
     }
 };
 
+// Contains the location of the function symbol, as well as
+/// the priority itself of the initialization function.
+pub const InitFuncLoc = struct {
+    /// object file index in the list of objects.
+    /// Unlike `SymbolLoc` this cannot be `null` as we never define
+    /// our own ctors.
+    file: u16,
+    /// Symbol index within the corresponding object file.
+    index: u32,
+    /// The priority in which the constructor must be called.
+    priority: u32,
+
+    /// From a given `InitFuncLoc` returns the corresponding function symbol
+    fn getSymbol(loc: InitFuncLoc, wasm: *const Wasm) *Symbol {
+        return getSymbolLoc(loc).getSymbol(wasm);
+    }
+
+    /// Turns the given `InitFuncLoc` into a `SymbolLoc`
+    fn getSymbolLoc(loc: InitFuncLoc) SymbolLoc {
+        return .{ .file = loc.file, .index = loc.index };
+    }
+
+    /// Returns true when `lhs` has a higher priority (e.i. value closer to 0) than `rhs`.
+    fn lessThan(ctx: void, lhs: InitFuncLoc, rhs: InitFuncLoc) bool {
+        _ = ctx;
+        return lhs.priority < rhs.priority;
+    }
+};
 /// Generic string table that duplicates strings
 /// and converts them into offsets instead.
 pub const StringTable = struct {
@@ -393,6 +424,16 @@ pub fn openPath(allocator: Allocator, sub_path: []const u8, options: link.Option
         }
     }
 
+    // create __wasm_call_ctors
+    {
+        const loc = try wasm_bin.createSyntheticSymbol("__wasm_call_ctors", .function);
+        const symbol = loc.getSymbol(wasm_bin);
+        symbol.setFlag(.WASM_SYM_VISIBILITY_HIDDEN);
+        // we do not know the function index until after we merged all sections.
+        // Therefore we set `symbol.index` and create its corresponding references
+        // at the end during `initializeCallCtorsFunction`.
+    }
+
     if (!options.strip and options.module != null) {
         wasm_bin.dwarf = Dwarf.init(allocator, &wasm_bin.base, options.target);
         try wasm_bin.initDebugSections();
@@ -896,6 +937,7 @@ pub fn deinit(wasm: *Wasm) void {
     wasm.wasm_globals.deinit(gpa);
     wasm.function_table.deinit(gpa);
     wasm.tables.deinit(gpa);
+    wasm.init_funcs.deinit(gpa);
     wasm.exports.deinit(gpa);
 
     wasm.string_table.deinit(gpa);
@@ -1698,6 +1740,130 @@ fn sortDataSegments(wasm: *Wasm) !void {
     wasm.data_segments = new_mapping;
 }
 
+/// Obtains all initfuncs from each object file, verifies its function signature,
+/// and then appends it to our final `init_funcs` list.
+/// After all functions have been inserted, the functions will be ordered based
+/// on their priority.
+/// NOTE: This function must be called before we merged any other section.
+/// This is because all init funcs in the object files contain references to the
+/// original functions and their types. We need to know the type to verify it doesn't
+/// contain any parameters.
+fn setupInitFunctions(wasm: *Wasm) !void {
+    for (wasm.objects.items) |object, file_index| {
+        try wasm.init_funcs.ensureUnusedCapacity(wasm.base.allocator, object.init_funcs.len);
+        for (object.init_funcs) |init_func| {
+            const symbol = object.symtable[init_func.symbol_index];
+            const ty: std.wasm.Type = if (symbol.isUndefined()) ty: {
+                const imp: types.Import = object.findImport(.function, symbol.index);
+                break :ty object.func_types[imp.kind.function];
+            } else ty: {
+                const func_index = symbol.index - object.importedCountByKind(.function);
+                const func = object.functions[func_index];
+                break :ty object.func_types[func.type_index];
+            };
+            if (ty.params.len != 0) {
+                log.err("constructor functions cannot take arguments: '{s}'", .{object.string_table.get(symbol.name)});
+                return error.InvalidInitFunc;
+            }
+            log.debug("appended init func '{s}'\n", .{object.string_table.get(symbol.name)});
+            wasm.init_funcs.appendAssumeCapacity(.{
+                .index = init_func.symbol_index,
+                .file = @intCast(u16, file_index),
+                .priority = init_func.priority,
+            });
+        }
+    }
+
+    // sort the initfunctions based on their priority
+    std.sort.sort(InitFuncLoc, wasm.init_funcs.items, {}, InitFuncLoc.lessThan);
+}
+
+/// Creates a function body for the `__wasm_call_ctors` symbol.
+/// Loops over all constructors found in `init_funcs` and calls them
+/// respectively based on their priority which was sorted by `setupInitFunctions`.
+/// NOTE: This function must be called after we merged all sections to ensure the
+/// references to the function stored in the symbol have been finalized so we end
+/// up calling the resolved function.
+fn initializeCallCtorsFunction(wasm: *Wasm) !void {
+    // No code to emit, so also no ctors to call
+    if (wasm.code_section_index == null) {
+        // Make sure to remove it from the resolved symbols so we do not emit
+        // it within any section. TODO: Remove this once we implement garbage collection.
+        const loc = wasm.globals.get(wasm.string_table.getOffset("__wasm_call_ctors").?).?;
+        std.debug.assert(wasm.resolved_symbols.swapRemove(loc));
+        return;
+    }
+
+    var function_body = std.ArrayList(u8).init(wasm.base.allocator);
+    defer function_body.deinit();
+    const writer = function_body.writer();
+
+    // Create the function body
+    {
+        // Write locals count (we have none)
+        try leb.writeULEB128(writer, @as(u32, 0));
+
+        // call constructors
+        for (wasm.init_funcs.items) |init_func_loc| {
+            const symbol = init_func_loc.getSymbol(wasm);
+            if (symbol.isUndefined()) {
+                std.debug.print("Undefined symbol '{s}'\n", .{wasm.string_table.get(symbol.name)});
+            }
+            std.debug.print("Symbol: {s}\n", .{init_func_loc.getSymbolLoc().getName(wasm)});
+            std.debug.assert(wasm.resolved_symbols.contains(init_func_loc.getSymbolLoc().finalLoc(wasm)));
+            const func = wasm.functions.values()[symbol.index - wasm.imported_functions_count];
+            const ty = wasm.func_types.items[func.type_index];
+
+            // Call function by its function index
+            try writer.writeByte(std.wasm.opcode(.call));
+            try leb.writeULEB128(writer, symbol.index);
+
+            // drop all returned values from the stack as __wasm_call_ctors has no return value
+            for (ty.returns) |_| {
+                try writer.writeByte(std.wasm.opcode(.drop));
+            }
+        }
+
+        // End function body
+        try writer.writeByte(std.wasm.opcode(.end));
+    }
+
+    const loc = wasm.globals.get(wasm.string_table.getOffset("__wasm_call_ctors").?).?;
+    const symbol = loc.getSymbol(wasm);
+    // create type (() -> nil) as we do not have any parameters or return value.
+    const ty_index = try wasm.putOrGetFuncType(.{ .params = &[_]std.wasm.Valtype{}, .returns = &[_]std.wasm.Valtype{} });
+    // create function with above type
+    const func_index = wasm.imported_functions_count + @intCast(u32, wasm.functions.count());
+    try wasm.functions.putNoClobber(
+        wasm.base.allocator,
+        .{ .file = null, .index = func_index },
+        .{ .type_index = ty_index },
+    );
+    symbol.index = func_index;
+
+    // create the atom that will be output into the final binary
+    const atom = try wasm.base.allocator.create(Atom);
+    errdefer wasm.base.allocator.destroy(atom);
+    atom.* = .{
+        .size = @intCast(u32, function_body.items.len),
+        .offset = 0,
+        .sym_index = loc.index,
+        .file = null,
+        .alignment = 1,
+        .next = null,
+        .prev = null,
+        .code = function_body.moveToUnmanaged(),
+        .dbg_info_atom = undefined,
+    };
+    try wasm.managed_atoms.append(wasm.base.allocator, atom);
+    try wasm.appendAtomAtIndex(wasm.code_section_index.?, atom);
+    try wasm.symbol_atom.putNoClobber(wasm.base.allocator, loc, atom);
+
+    // `allocateAtoms` has already been called, set the atom's offset manually.
+    // This is fine to do manually as we insert the atom at the very end.
+    atom.offset = atom.prev.?.offset + atom.prev.?.size;
+}
+
 fn setupImports(wasm: *Wasm) !void {
     log.debug("Merging imports", .{});
     var discarded_it = wasm.discarded.keyIterator();
@@ -1870,16 +2036,17 @@ fn setupExports(wasm: *Wasm) !void {
 
     const force_exp_names = wasm.base.options.export_symbol_names;
     if (force_exp_names.len > 0) {
-        var failed_exports = try std.ArrayList([]const u8).initCapacity(wasm.base.allocator, force_exp_names.len);
-        defer failed_exports.deinit();
+        var failed_exports = false;
 
         for (force_exp_names) |exp_name| {
             const name_index = wasm.string_table.getOffset(exp_name) orelse {
-                failed_exports.appendAssumeCapacity(exp_name);
+                log.err("could not export '{s}', symbol not found", .{exp_name});
+                failed_exports = true;
                 continue;
             };
             const loc = wasm.globals.get(name_index) orelse {
-                failed_exports.appendAssumeCapacity(exp_name);
+                log.err("could not export '{s}', symbol not found", .{exp_name});
+                failed_exports = true;
                 continue;
             };
 
@@ -1887,10 +2054,7 @@ fn setupExports(wasm: *Wasm) !void {
             symbol.setFlag(.WASM_SYM_EXPORTED);
         }
 
-        if (failed_exports.items.len > 0) {
-            for (failed_exports.items) |exp_name| {
-                log.err("could not export '{s}', symbol not found", .{exp_name});
-            }
+        if (failed_exports) {
             return error.MissingSymbol;
         }
     }
@@ -1948,6 +2112,7 @@ fn setupStart(wasm: *Wasm) !void {
 
     const symbol_loc = wasm.globals.get(symbol_name_offset) orelse {
         log.err("Entry symbol '{s}' not found", .{entry_name});
+        return error.MissingSymbol;
     };
     const symbol = symbol_loc.getSymbol(wasm);
     if (symbol.tag != .function) {
@@ -2503,6 +2668,7 @@ fn linkWithZld(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Node) l
     try wasm.resolveSymbolsInArchives();
     try wasm.checkUndefinedSymbols();
 
+    try wasm.setupInitFunctions();
     try wasm.setupStart();
     try wasm.setupImports();
 
@@ -2515,6 +2681,7 @@ fn linkWithZld(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Node) l
     wasm.mapFunctionTable();
     try wasm.mergeSections();
     try wasm.mergeTypes();
+    try wasm.initializeCallCtorsFunction();
     try wasm.setupExports();
     try wasm.writeToFile(enabled_features, emit_features_count, arena);
 
@@ -2587,6 +2754,7 @@ pub fn flushModule(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Nod
     // When we finish/error we reset the state of the linker
     // So we can rebuild the binary file on each incremental update
     defer wasm.resetState();
+    try wasm.setupInitFunctions();
     try wasm.setupStart();
     try wasm.setupImports();
     if (wasm.base.options.module) |mod| {
@@ -2629,6 +2797,7 @@ pub fn flushModule(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Nod
     wasm.mapFunctionTable();
     try wasm.mergeSections();
     try wasm.mergeTypes();
+    try wasm.initializeCallCtorsFunction();
     try wasm.setupExports();
     try wasm.writeToFile(enabled_features, emit_features_count, arena);
 }
@@ -3909,8 +4078,8 @@ pub fn getTypeIndex(wasm: *const Wasm, func_type: std.wasm.Type) ?u32 {
     return null;
 }
 
-/// Searches for an a matching function signature, when not found
-/// a new entry will be made. The index of the existing/new signature will be returned.
+/// Searches for a matching function signature. When no matching signature is found,
+/// a new entry will be made. The value returned is the index of the type within `wasm.func_types`.
 pub fn putOrGetFuncType(wasm: *Wasm, func_type: std.wasm.Type) !u32 {
     if (wasm.getTypeIndex(func_type)) |index| {
         return index;
src/link.zig
@@ -716,6 +716,7 @@ pub const File = struct {
         InvalidFeatureSet,
         InvalidFormat,
         InvalidIndex,
+        InvalidInitFunc,
         InvalidMagicByte,
         InvalidWasmVersion,
         LLDCrashed,