Commit 7226ad2670

Luuk de Gram <luuk@degram.dev>
2021-11-28 12:49:04
wasm-link: Implement indirect function table
The function table contains all function pointers that are called using `call_indirect`. During codegen, we create a relocation so that the linker can resolve the correct index into the table and store this value in the data section at the location of the pointer.
1 parent 9b5d614
Changed files (5)
src/arch/wasm/CodeGen.zig
@@ -1065,9 +1065,16 @@ fn airCall(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const pl_op = self.air.instructions.items(.data)[inst].pl_op;
     const extra = self.air.extraData(Air.Call, pl_op.payload);
     const args = self.air.extra[extra.end..][0..extra.data.args_len];
+    const ty = self.air.typeOf(pl_op.operand);
 
-    const target: *Decl = blk: {
-        const func_val = self.air.value(pl_op.operand).?;
+    const fn_ty = switch (ty.zigTypeTag()) {
+        .Fn => ty,
+        .Pointer => ty.childType(),
+        else => unreachable,
+    };
+
+    const target: ?*Decl = blk: {
+        const func_val = self.air.value(pl_op.operand) orelse break :blk null;
 
         if (func_val.castTag(.function)) |func| {
             break :blk func.data.owner_decl;
@@ -1082,9 +1089,24 @@ fn airCall(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         try self.emitWValue(arg_val);
     }
 
-    try self.addLabel(.call, target.link.wasm.sym_index);
+    if (target) |direct| {
+        try self.addLabel(.call, direct.link.wasm.sym_index);
+    } else {
+        // in this case we call a function pointer
+        // so load its value onto the stack
+        std.debug.assert(ty.zigTypeTag() == .Pointer);
+        const operand = self.resolveInst(pl_op.operand);
+        const result = try self.load(operand, fn_ty, operand.local_with_offset.offset);
+        try self.addLabel(.local_get, result.local);
+
+        var fn_type = try self.genFunctype(fn_ty);
+        defer fn_type.deinit(self.gpa);
+
+        const fn_type_index = try self.bin_file.putOrGetFuncType(fn_type);
+        try self.addLabel(.call_indirect, fn_type_index);
+    }
 
-    const ret_ty = target.ty.fnReturnType();
+    const ret_ty = fn_ty.fnReturnType();
     switch (ret_ty.zigTypeTag()) {
         .Void, .NoReturn => return WValue.none,
         else => {
src/arch/wasm/Emit.zig
@@ -47,6 +47,7 @@ pub fn emitMir(emit: *Emit) InnerError!void {
 
             // relocatables
             .call => try emit.emitCall(inst),
+            .call_indirect => try emit.emitCallIndirect(inst),
             .global_get => try emit.emitGlobal(tag, inst),
             .global_set => try emit.emitGlobal(tag, inst),
             .memory_address => try emit.emitMemAddress(inst),
@@ -276,6 +277,13 @@ fn emitCall(emit: *Emit, inst: Mir.Inst.Index) !void {
     });
 }
 
+fn emitCallIndirect(emit: *Emit, inst: Mir.Inst.Index) !void {
+    const type_index = emit.mir.instructions.items(.data)[inst].label; // `label` carries the function type index for `call_indirect`
+    try emit.code.append(std.wasm.opcode(.call_indirect));
+    try leb128.writeULEB128(emit.code.writer(), type_index); // binary encoding is `0x11 typeidx tableidx`: the type index comes first
+    try leb128.writeULEB128(emit.code.writer(), @as(u32, 0)); // table index; TODO: Emit relocation for table index
+}
+
 fn emitMemAddress(emit: *Emit, inst: Mir.Inst.Index) !void {
     const symbol_index = emit.mir.instructions.items(.data)[inst].label;
     try emit.code.append(std.wasm.opcode(.i32_const));
src/arch/wasm/Mir.zig
@@ -69,6 +69,11 @@ pub const Inst = struct {
         ///
         /// Uses `label`
         call = 0x10,
+        /// Calls a function pointer by its function signature
+        /// and index into the function table.
+        ///
+        /// Uses `label`
+        call_indirect = 0x11,
         /// Loads a local at given index onto the stack.
         ///
         /// Uses `label`
src/link/Wasm/Atom.zig
@@ -129,7 +129,7 @@ fn relocationValue(relocation: types.Relocation, wasm_bin: *const Wasm) !u64 {
         .R_WASM_TABLE_INDEX_I64,
         .R_WASM_TABLE_INDEX_SLEB,
         .R_WASM_TABLE_INDEX_SLEB64,
-        => return error.TodoImplementTableIndex, // find table index from a function symbol
+        => return wasm_bin.function_table.get(relocation.index) orelse 0,
         .R_WASM_TYPE_INDEX_LEB => wasm_bin.functions.items[symbol.index].type_index,
         .R_WASM_GLOBAL_INDEX_I32,
         .R_WASM_GLOBAL_INDEX_LEB,
src/link/Wasm.zig
@@ -79,7 +79,9 @@ memories: wasm.Memory = .{ .limits = .{ .min = 0, .max = null } },
 /// Indirect function table, used to call function pointers
 /// When this is non-zero, we must emit a table entry,
 /// as well as an 'elements' section.
-function_table: std.ArrayListUnmanaged(Symbol) = .{},
+///
+/// Note: Key is symbol index, value represents the index into the table
+function_table: std.AutoHashMapUnmanaged(u32, u32) = .{},
 
 pub const Segment = struct {
     alignment: u32,
@@ -276,7 +278,7 @@ pub fn updateDecl(self: *Wasm, module: *Module, decl: *Module.Decl) !void {
     defer codegen.deinit();
 
     // generate the 'code' section for the function declaration
-    const result = codegen.gen(decl.ty, decl.val) catch |err| switch (err) {
+    const result = codegen.genDecl(decl.ty, decl.val) catch |err| switch (err) {
         error.CodegenFail => {
             decl.analysis = .codegen_failure;
             try module.failed_decls.put(module.gpa, decl, codegen.err_msg);
@@ -334,6 +336,25 @@ pub fn freeDecl(self: *Wasm, decl: *Module.Decl) void {
             else => unreachable,
         }
     }
+
+    // maybe remove from function table if needed
+    if (decl.ty.zigTypeTag() == .Fn) {
+        _ = self.function_table.remove(atom.sym_index);
+    }
+}
+
+/// Registers `symbol_index` in the indirect function table, using the entry count before insertion as its table index.
+pub fn addTableFunction(self: *Wasm, symbol_index: u32) !void {
+    const index = @intCast(u32, self.function_table.count()); // next slot == number of entries currently stored
+    try self.function_table.put(self.base.allocator, symbol_index, index); // key: symbol index, value: index into the table
+}
+
+fn mapFunctionTable(self: *Wasm) void { // renumbers every stored table index to 0..count-1, in the map's iteration order
+    var it = self.function_table.valueIterator();
+    var index: u32 = 0; // running slot number handed to the next visited entry
+    while (it.next()) |value_ptr| : (index += 1) {
+        value_ptr.* = index; // overwrite in place; presumably closes gaps left by removals in `freeDecl` — TODO confirm
+    }
 }
 
 fn addOrUpdateImport(self: *Wasm, decl: *Module.Decl) !void {
@@ -583,6 +604,7 @@ pub fn flushModule(self: *Wasm, comp: *Compilation) !void {
 
     try self.setupMemory();
     try self.allocateAtoms();
+    self.mapFunctionTable();
 
     const file = self.base.file.?;
     const header_size = 5 + 1;
@@ -662,6 +684,22 @@ pub fn flushModule(self: *Wasm, comp: *Compilation) !void {
         );
     }
 
+    if (self.function_table.count() > 0) {
+        const header_offset = try reserveVecSectionHeader(file);
+        const writer = file.writer();
+
+        try leb.writeULEB128(writer, wasm.reftype(.funcref));
+        try emitLimits(writer, .{ .min = 1, .max = null });
+
+        try writeVecSectionHeader(
+            file,
+            header_offset,
+            .table,
+            @intCast(u32, (try file.getPos()) - header_offset - header_size),
+            @as(u32, 1),
+        );
+    }
+
     // Memory section
     if (!self.base.options.import_memory) {
         const header_offset = try reserveVecSectionHeader(file);
@@ -743,6 +781,31 @@ pub fn flushModule(self: *Wasm, comp: *Compilation) !void {
         );
     }
 
+    // element section (function table)
+    if (self.function_table.count() > 0) {
+        const header_offset = try reserveVecSectionHeader(file);
+        const writer = file.writer();
+
+        var flags: u32 = 0x2; // Yes we have a table
+        try leb.writeULEB128(writer, flags);
+        try leb.writeULEB128(writer, @as(u32, 0)); // index of that table. TODO: Store synthetic symbols
+        try emitInit(writer, .{ .i32_const = 0 });
+        try leb.writeULEB128(writer, @as(u8, 0));
+        try leb.writeULEB128(writer, @intCast(u32, self.function_table.count()));
+        var symbol_it = self.function_table.keyIterator();
+        while (symbol_it.next()) |symbol_index_ptr| {
+            try leb.writeULEB128(writer, self.symbols.items[symbol_index_ptr.*].index);
+        }
+
+        try writeVecSectionHeader(
+            file,
+            header_offset,
+            .element,
+            @intCast(u32, (try file.getPos()) - header_offset - header_size),
+            @as(u32, 1),
+        );
+    }
+
     // Code section
     if (self.code_section_index) |code_index| {
         const header_offset = try reserveVecSectionHeader(file);
@@ -1233,16 +1296,3 @@ pub fn putOrGetFuncType(self: *Wasm, func_type: wasm.Type) !u32 {
     });
     return index;
 }
-
-/// From a given index and an `ExternalKind`, finds the corresponding Import.
-/// This is due to indexes for imports being unique per type, rather than across all imports.
-fn findImport(self: Wasm, index: u32, external_type: wasm.ExternalKind) ?*wasm.Import {
-    var current_index: u32 = 0;
-    for (self.imports.items) |*import| {
-        if (import.kind == external_type) {
-            if (current_index == index) return import;
-            current_index += 1;
-        }
-    }
-    return null;
-}