Commit 9fce1df4cd

Luuk de Gram <luuk@degram.dev>
2023-03-17 06:32:37
wasm-linker: implement runtime TLS relocations
1 parent 9d13c22
Changed files (3)
src/link/Wasm/Object.zig
@@ -930,11 +930,29 @@ pub fn parseIntoAtoms(object: *Object, gpa: Allocator, object_index: u16, wasm_b
                 reloc.offset -= relocatable_data.offset;
                 try atom.relocs.append(gpa, reloc);
 
-                if (relocation.isTableIndex()) {
-                    try wasm_bin.function_table.put(gpa, .{
-                        .file = object_index,
-                        .index = relocation.index,
-                    }, 0);
+                switch (relocation.relocation_type) {
+                    .R_WASM_TABLE_INDEX_I32,
+                    .R_WASM_TABLE_INDEX_I64,
+                    .R_WASM_TABLE_INDEX_SLEB,
+                    .R_WASM_TABLE_INDEX_SLEB64,
+                    => {
+                        try wasm_bin.function_table.put(gpa, .{
+                            .file = object_index,
+                            .index = relocation.index,
+                        }, 0);
+                    },
+                    .R_WASM_GLOBAL_INDEX_I32,
+                    .R_WASM_GLOBAL_INDEX_LEB,
+                    => {
+                        const sym = object.symtable[relocation.index];
+                        if (sym.tag != .global) {
+                            try wasm_bin.got_symbols.append(
+                                wasm_bin.base.allocator,
+                                .{ .file = object_index, .index = relocation.index },
+                            );
+                        }
+                    },
+                    else => {},
                 }
             }
         }
src/link/Wasm/types.zig
@@ -71,18 +71,6 @@ pub const Relocation = struct {
         };
     }
 
-    /// Returns true when the relocation represents a table index relocatable
-    pub fn isTableIndex(self: Relocation) bool {
-        return switch (self.relocation_type) {
-            .R_WASM_TABLE_INDEX_I32,
-            .R_WASM_TABLE_INDEX_I64,
-            .R_WASM_TABLE_INDEX_SLEB,
-            .R_WASM_TABLE_INDEX_SLEB64,
-            => true,
-            else => false,
-        };
-    }
-
     pub fn format(self: Relocation, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
         _ = fmt;
         _ = options;
src/link/Wasm.zig
@@ -139,6 +139,8 @@ archives: std.ArrayListUnmanaged(Archive) = .{},
 
 /// A map of global names (read: offset into string table) to their symbol location
 globals: std.AutoHashMapUnmanaged(u32, SymbolLoc) = .{},
+/// The list of GOT symbols and their location
+got_symbols: std.ArrayListUnmanaged(SymbolLoc) = .{},
 /// Maps discarded symbols and their positions to the location of the symbol
 /// it was resolved to
 discarded: std.AutoHashMapUnmanaged(SymbolLoc, SymbolLoc) = .{},
@@ -635,6 +637,15 @@ fn parseArchive(wasm: *Wasm, path: []const u8, force_load: bool) !bool {
     return true;
 }
 
+fn requiresTLSReloc(wasm: *const Wasm) bool {
+    for (wasm.got_symbols.items) |loc| {
+        if (loc.getSymbol(wasm).isTLS()) {
+            return true;
+        }
+    }
+    return false;
+}
+
 fn resolveSymbolsInObject(wasm: *Wasm, object_index: u16) !void {
     const object: Object = wasm.objects.items[object_index];
     log.debug("Resolving symbols in object: '{s}'", .{object.name});
@@ -813,6 +824,48 @@ fn resolveSymbolsInArchives(wasm: *Wasm) !void {
     }
 }
 
+fn setupTLSRelocationsFunction(wasm: *Wasm) !void {
+    // When we have TLS GOT entries and shared memory is enabled,
+    // we must perform runtime relocations or else we don't create the function.
+    if (!wasm.base.options.shared_memory or !wasm.requiresTLSReloc()) {
+        return;
+    }
+
+    // const loc = try wasm.createSyntheticSymbol("__wasm_apply_global_tls_relocs");
+    var function_body = std.ArrayList(u8).init(wasm.base.allocator);
+    defer function_body.deinit();
+    const writer = function_body.writer();
+
+    // locals (we have none)
+    try writer.writeByte(0);
+    for (wasm.got_symbols.items, 0..) |got_loc, got_index| {
+        const sym: *Symbol = got_loc.getSymbol(wasm);
+        if (!sym.isTLS()) continue; // only relocate TLS symbols
+        if (sym.tag == .data and sym.isDefined()) {
+            // get __tls_base
+            try writer.writeByte(std.wasm.opcode(.global_get));
+            try leb.writeULEB128(writer, wasm.findGlobalSymbol("__tls_base").?.getSymbol(wasm).index);
+
+            // add the virtual address of the symbol
+            try writer.writeByte(std.wasm.opcode(.i32_const));
+            try leb.writeULEB128(writer, sym.virtual_address);
+        } else if (sym.tag == .function) {
+            @panic("TODO: relocate GOT entry of function");
+        } else continue;
+
+        try writer.writeByte(std.wasm.opcode(.i32_add));
+        try writer.writeByte(std.wasm.opcode(.global_set));
+        try leb.writeULEB128(writer, wasm.imported_globals_count + @intCast(u32, wasm.wasm_globals.items.len + got_index));
+    }
+    try writer.writeByte(std.wasm.opcode(.end));
+
+    try wasm.createSyntheticFunction(
+        "__wasm_apply_global_tls_relocs",
+        std.wasm.Type{ .params = &.{}, .returns = &.{} },
+        &function_body,
+    );
+}
+
 fn validateFeatures(
     wasm: *const Wasm,
     to_emit: *[@typeInfo(types.Feature.Tag).Enum.fields.len]bool,
@@ -2083,6 +2136,14 @@ fn initializeTLSFunction(wasm: *Wasm) !void {
         try leb.writeULEB128(writer, @as(u32, 0));
     }
 
+    // If we have to perform any TLS relocations, call the corresponding function
+    // which performs all runtime TLS relocations. This is a synthetic function,
+    // generated by the linker.
+    if (wasm.findGlobalSymbol("__wasm_apply_global_tls_relocs")) |loc| {
+        try writer.writeByte(std.wasm.opcode(.call));
+        try leb.writeULEB128(writer, loc.getSymbol(wasm).index);
+    }
+
     try writer.writeByte(std.wasm.opcode(.end));
 
     try wasm.createSyntheticFunction(
@@ -2939,6 +3000,7 @@ fn linkWithZld(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Node) l
     try wasm.mergeSections();
     try wasm.mergeTypes();
     try wasm.initializeCallCtorsFunction();
+    try wasm.setupTLSRelocationsFunction();
     try wasm.initializeTLSFunction();
     try wasm.setupExports();
     try wasm.writeToFile(enabled_features, emit_features_count, arena);
@@ -3059,6 +3121,7 @@ pub fn flushModule(wasm: *Wasm, comp: *Compilation, prog_node: *std.Progress.Nod
     try wasm.mergeSections();
     try wasm.mergeTypes();
     try wasm.initializeCallCtorsFunction();
+    try wasm.setupTLSRelocationsFunction();
     try wasm.initializeTLSFunction();
     try wasm.setupExports();
     try wasm.writeToFile(enabled_features, emit_features_count, arena);