Commit 728103467e

Andrew Kelley <andrew@ziglang.org>
2024-12-21 07:15:21
wasm linker: implement indirect function calls
1 parent fbbb54b
Changed files (5)
src/arch/wasm/CodeGen.zig
@@ -1021,7 +1021,20 @@ fn emitWValue(cg: *CodeGen, value: WValue) InnerError!void {
         .float32 => |val| try cg.addInst(.{ .tag = .f32_const, .data = .{ .float32 = val } }),
         .float64 => |val| try cg.addFloat64(val),
         .nav_ref => |nav_ref| {
-            if (nav_ref.offset == 0) {
+            const wasm = cg.wasm;
+            const comp = wasm.base.comp;
+            const zcu = comp.zcu.?;
+            const ip = &zcu.intern_pool;
+            const ip_index = ip.getNav(nav_ref.nav_index).status.resolved.val;
+            if (ip.isFunctionType(ip.typeOf(ip_index))) {
+                assert(nav_ref.offset == 0);
+                const gop = try wasm.indirect_function_table.getOrPut(comp.gpa, ip_index);
+                if (!gop.found_existing) gop.value_ptr.* = {};
+                try cg.addInst(.{
+                    .tag = .func_ref,
+                    .data = .{ .indirect_function_table_index = @enumFromInt(gop.index) },
+                });
+            } else if (nav_ref.offset == 0) {
                 try cg.addInst(.{ .tag = .nav_ref, .data = .{ .nav_index = nav_ref.nav_index } });
             } else {
                 try cg.addInst(.{
@@ -1037,8 +1050,19 @@ fn emitWValue(cg: *CodeGen, value: WValue) InnerError!void {
         },
         .uav_ref => |uav| {
             const wasm = cg.wasm;
-            const is_obj = wasm.base.comp.config.output_mode == .Obj;
-            if (uav.offset == 0) {
+            const comp = wasm.base.comp;
+            const is_obj = comp.config.output_mode == .Obj;
+            const zcu = comp.zcu.?;
+            const ip = &zcu.intern_pool;
+            if (ip.isFunctionType(ip.typeOf(uav.ip_index))) {
+                assert(uav.offset == 0);
+                const gop = try wasm.indirect_function_table.getOrPut(comp.gpa, uav.ip_index);
+                if (!gop.found_existing) gop.value_ptr.* = {};
+                try cg.addInst(.{
+                    .tag = .func_ref,
+                    .data = .{ .indirect_function_table_index = @enumFromInt(gop.index) },
+                });
+            } else if (uav.offset == 0) {
                 try cg.addInst(.{
                     .tag = .uav_ref,
                     .data = if (is_obj) .{
src/arch/wasm/Emit.zig
@@ -76,7 +76,16 @@ pub fn lowerToCode(emit: *Emit) Error!void {
             inst += 1;
             continue :loop tags[inst];
         },
-
+        .func_ref => {
+            code.appendAssumeCapacity(@intFromEnum(std.wasm.Opcode.i32_const));
+            if (is_obj) {
+                @panic("TODO");
+            } else {
+                leb.writeUleb128(code.fixedWriter(), @intFromEnum(datas[inst].indirect_function_table_index)) catch unreachable;
+            }
+            inst += 1;
+            continue :loop tags[inst];
+        },
         .dbg_line => {
             inst += 1;
             continue :loop tags[inst];
@@ -938,40 +947,23 @@ fn navRefOff(wasm: *Wasm, code: *std.ArrayListUnmanaged(u8), data: Mir.NavRefOff
     const gpa = comp.gpa;
     const is_obj = comp.config.output_mode == .Obj;
     const nav_ty = ip.getNav(data.nav_index).typeOf(ip);
+    assert(!ip.isFunctionType(nav_ty));
 
     try code.ensureUnusedCapacity(gpa, 11);
 
-    if (ip.isFunctionType(nav_ty)) {
-        code.appendAssumeCapacity(@intFromEnum(std.wasm.Opcode.i32_const));
-        assert(data.offset == 0);
-        if (is_obj) {
-            try wasm.out_relocs.append(gpa, .{
-                .offset = @intCast(code.items.len),
-                .pointee = .{ .symbol_index = try wasm.navSymbolIndex(data.nav_index) },
-                .tag = .TABLE_INDEX_SLEB,
-                .addend = data.offset,
-            });
-            code.appendNTimesAssumeCapacity(0, 5);
-        } else {
-            const function_imports_len: u32 = @intCast(wasm.function_imports.entries.len);
-            const func_index = Wasm.FunctionIndex.fromIpNav(wasm, data.nav_index).?;
-            leb.writeUleb128(code.fixedWriter(), function_imports_len + @intFromEnum(func_index)) catch unreachable;
-        }
+    const opcode: std.wasm.Opcode = if (is_wasm32) .i32_const else .i64_const;
+    code.appendAssumeCapacity(@intFromEnum(opcode));
+    if (is_obj) {
+        try wasm.out_relocs.append(gpa, .{
+            .offset = @intCast(code.items.len),
+            .pointee = .{ .symbol_index = try wasm.navSymbolIndex(data.nav_index) },
+            .tag = if (is_wasm32) .MEMORY_ADDR_LEB else .MEMORY_ADDR_LEB64,
+            .addend = data.offset,
+        });
+        code.appendNTimesAssumeCapacity(0, if (is_wasm32) 5 else 10);
     } else {
-        const opcode: std.wasm.Opcode = if (is_wasm32) .i32_const else .i64_const;
-        code.appendAssumeCapacity(@intFromEnum(opcode));
-        if (is_obj) {
-            try wasm.out_relocs.append(gpa, .{
-                .offset = @intCast(code.items.len),
-                .pointee = .{ .symbol_index = try wasm.navSymbolIndex(data.nav_index) },
-                .tag = if (is_wasm32) .MEMORY_ADDR_LEB else .MEMORY_ADDR_LEB64,
-                .addend = data.offset,
-            });
-            code.appendNTimesAssumeCapacity(0, if (is_wasm32) 5 else 10);
-        } else {
-            const addr = wasm.navAddr(data.nav_index);
-            leb.writeUleb128(code.fixedWriter(), @as(u32, @intCast(@as(i64, addr) + data.offset))) catch unreachable;
-        }
+        const addr = wasm.navAddr(data.nav_index);
+        leb.writeUleb128(code.fixedWriter(), @as(u32, @intCast(@as(i64, addr) + data.offset))) catch unreachable;
     }
 }
 
src/arch/wasm/Mir.zig
@@ -65,9 +65,7 @@ pub const Inst = struct {
         /// Lowers to an i32_const (wasm32) or i64_const (wasm64) which is the
         /// memory address of a named constant.
         ///
-        /// When this refers to a function, this always lowers to an i32_const
-        /// which is the function index. When emitting an object file, this
-        /// adds a `Wasm.Relocation.Tag.TABLE_INDEX_SLEB` relocation.
+        /// May not refer to a function.
         ///
         /// Uses `nav_index`.
         nav_ref,
@@ -75,10 +73,15 @@ pub const Inst = struct {
         /// memory address of named constant, offset by an integer value.
         /// When emitting an object file, this adds a relocation.
         ///
-        /// This may not refer to a function.
+        /// May not refer to a function.
         ///
         /// Uses `payload` pointing to a `NavRefOff`.
         nav_ref_off,
+        /// Lowers to an i32_const which is the index of the function in the
+        /// table section.
+        ///
+        /// Uses `indirect_function_table_index`.
+        func_ref,
         /// Inserts debug information about the current line and column
         /// of the source code
         ///
@@ -88,12 +91,6 @@ pub const Inst = struct {
         /// names.
         /// Uses `tag`.
         errors_len,
-        /// Lowers to an i32_const (wasm32) or i64_const (wasm64) containing
-        /// the base address of the table of error code names, with each
-        /// element being a null-terminated slice.
-        ///
-        /// Uses `tag`.
-        error_name_table_ref,
         /// Represents the end of a function body or an initialization expression
         ///
         /// Uses `tag` (no additional data).
@@ -115,6 +112,12 @@ pub const Inst = struct {
         ///
         /// Uses `tag`.
         @"return" = 0x0F,
+        /// Lowers to an i32_const (wasm32) or i64_const (wasm64) containing
+        /// the base address of the table of error code names, with each
+        /// element being a null-terminated slice.
+        ///
+        /// Uses `tag`.
+        error_name_table_ref,
         /// Calls a function using `nav_index`.
         call_nav,
         /// Calls a function pointer by its function signature
@@ -612,6 +615,7 @@ pub const Inst = struct {
         intrinsic: Intrinsic,
         uav_obj: Wasm.UavsObjIndex,
         uav_exe: Wasm.UavsExeIndex,
+        indirect_function_table_index: Wasm.IndirectFunctionTableIndex,
 
         comptime {
             switch (builtin.mode) {
src/link/Wasm/Flush.zig
@@ -33,8 +33,6 @@ missing_exports: std.AutoArrayHashMapUnmanaged(String, void) = .empty,
 function_imports: std.AutoArrayHashMapUnmanaged(String, Wasm.FunctionImportId) = .empty,
 global_imports: std.AutoArrayHashMapUnmanaged(String, Wasm.GlobalImportId) = .empty,
 
-indirect_function_table: std.AutoArrayHashMapUnmanaged(Wasm.OutputFunctionIndex, u32) = .empty,
-
 /// For debug purposes only.
 memory_layout_finished: bool = false,
 
@@ -42,7 +40,6 @@ pub fn clear(f: *Flush) void {
     f.data_segments.clearRetainingCapacity();
     f.data_segment_groups.clearRetainingCapacity();
     f.binary_bytes.clearRetainingCapacity();
-    f.indirect_function_table.clearRetainingCapacity();
     f.memory_layout_finished = false;
 }
 
@@ -53,7 +50,6 @@ pub fn deinit(f: *Flush, gpa: Allocator) void {
     f.missing_exports.deinit(gpa);
     f.function_imports.deinit(gpa);
     f.global_imports.deinit(gpa);
-    f.indirect_function_table.deinit(gpa);
     f.* = undefined;
 }
 
@@ -72,10 +68,6 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void {
     };
     const is_obj = comp.config.output_mode == .Obj;
     const allow_undefined = is_obj or wasm.import_symbols;
-    //const undef_byte: u8 = switch (comp.root_mod.optimize_mode) {
-    //    .Debug, .ReleaseSafe => 0xaa,
-    //    .ReleaseFast, .ReleaseSmall => 0x00,
-    //};
 
     if (comp.zcu) |zcu| {
         const ip: *const InternPool = &zcu.intern_pool; // No mutations allowed!
@@ -215,6 +207,12 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void {
         wasm.functions.putAssumeCapacity(.__wasm_init_tls, {});
     }
 
+    try wasm.tables.ensureUnusedCapacity(gpa, 1);
+
+    if (wasm.indirect_function_table.entries.len > 0) {
+        wasm.tables.putAssumeCapacity(.__indirect_function_table, {});
+    }
+
     // Sort order:
     // 0. Segment category (tls, data, zero)
     // 1. Segment name prefix
@@ -642,34 +640,31 @@ pub fn finish(f: *Flush, wasm: *Wasm) !void {
         replaceVecSectionHeader(binary_bytes, header_offset, .start, @intFromEnum(func_index));
     }
 
-    // element section (function table)
-    if (f.indirect_function_table.count() > 0) {
-        @panic("TODO");
-        //const header_offset = try reserveVecSectionHeader(gpa, binary_bytes);
-
-        //const table_loc = wasm.globals.get(wasm.preloaded_strings.__indirect_function_table).?;
-        //const table_sym = wasm.finalSymbolByLoc(table_loc);
+    // element section
+    if (wasm.indirect_function_table.entries.len > 0) {
+        const header_offset = try reserveVecSectionHeader(gpa, binary_bytes);
 
-        //const flags: u32 = if (table_sym.index == 0) 0x0 else 0x02; // passive with implicit 0-index table or set table index manually
-        //try leb.writeUleb128(binary_writer, flags);
-        //if (flags == 0x02) {
-        //    try leb.writeUleb128(binary_writer, table_sym.index);
-        //}
-        //try emitInit(binary_writer, .{ .i32_const = 1 }); // We start at index 1, so unresolved function pointers are invalid
-        //if (flags == 0x02) {
-        //    try leb.writeUleb128(binary_writer, @as(u8, 0)); // represents funcref
-        //}
-        //try leb.writeUleb128(binary_writer, @as(u32, @intCast(f.indirect_function_table.count())));
-        //var symbol_it = f.indirect_function_table.keyIterator();
-        //while (symbol_it.next()) |symbol_loc_ptr| {
-        //    const sym = wasm.finalSymbolByLoc(symbol_loc_ptr.*);
-        //    assert(sym.flags.alive);
-        //    assert(sym.index < wasm.functions.count() + wasm.imported_functions_count);
-        //    try leb.writeUleb128(binary_writer, sym.index);
-        //}
+        // indirect function table elements
+        const table_index: u32 = @intCast(wasm.tables.getIndex(.__indirect_function_table).?);
+        // passive with implicit 0-index table or set table index manually
+        const flags: u32 = if (table_index == 0) 0x0 else 0x02;
+        try leb.writeUleb128(binary_writer, flags);
+        if (flags == 0x02) {
+            try leb.writeUleb128(binary_writer, table_index);
+        }
+        // We start at index 1, so unresolved function pointers are invalid
+        try emitInit(binary_writer, .{ .i32_const = 1 });
+        if (flags == 0x02) {
+            try leb.writeUleb128(binary_writer, @as(u8, 0)); // represents funcref
+        }
+        try leb.writeUleb128(binary_writer, @as(u32, @intCast(wasm.indirect_function_table.entries.len)));
+        for (wasm.indirect_function_table.keys()) |ip_index| {
+            const func_index: Wasm.OutputFunctionIndex = .fromIpIndex(wasm, ip_index);
+            try leb.writeUleb128(binary_writer, @intFromEnum(func_index));
+        }
 
-        //replaceVecSectionHeader(binary_bytes, header_offset, .element, 1);
-        //section_index += 1;
+        replaceVecSectionHeader(binary_bytes, header_offset, .element, 1);
+        section_index += 1;
     }
 
     // When the shared-memory option is enabled, we *must* emit the 'data count' section.
src/link/Wasm.zig
@@ -235,6 +235,10 @@ global_imports: std.AutoArrayHashMapUnmanaged(String, GlobalImportId) = .empty,
 tables: std.AutoArrayHashMapUnmanaged(TableImport.Resolution, void) = .empty,
 table_imports: std.AutoArrayHashMapUnmanaged(String, TableImport.Index) = .empty,
 
+/// All functions that have had their address taken and therefore might be
+/// called via a `call_indirect` function.
+indirect_function_table: std.AutoArrayHashMapUnmanaged(InternPool.Index, void) = .empty,
+
 error_name_table_ref_count: u32 = 0,
 
 /// Set to true if any `GLOBAL_INDEX` relocation is encountered with
@@ -260,6 +264,11 @@ error_name_bytes: std.ArrayListUnmanaged(u8) = .empty,
 /// is stored. No need to serialize; trivially reconstructed.
 error_name_offs: std.ArrayListUnmanaged(u32) = .empty,
 
+/// Index into `Wasm.indirect_function_table`.
+pub const IndirectFunctionTableIndex = enum(u32) {
+    _,
+};
+
 pub const UavFixup = extern struct {
     uavs_exe_index: UavsExeIndex,
     /// Index into `string_bytes`.
@@ -335,17 +344,24 @@ pub const OutputFunctionIndex = enum(u32) {
         return @enumFromInt(wasm.function_imports.entries.len + @intFromEnum(index));
     }
 
+    pub fn fromIpIndex(wasm: *const Wasm, ip_index: InternPool.Index) OutputFunctionIndex {
+        const zcu = wasm.base.comp.zcu.?;
+        const ip = &zcu.intern_pool;
+        return switch (ip.indexToKey(ip_index)) {
+            .@"extern" => |ext| {
+                const name = wasm.getExistingString(ext.name.toSlice(ip)).?;
+                if (wasm.function_imports.getIndex(name)) |i| return @enumFromInt(i);
+                return fromFunctionIndex(wasm, FunctionIndex.fromSymbolName(wasm, name).?);
+            },
+            else => fromResolution(wasm, .fromIpIndex(wasm, ip_index)).?,
+        };
+    }
+
     pub fn fromIpNav(wasm: *const Wasm, nav_index: InternPool.Nav.Index) OutputFunctionIndex {
         const zcu = wasm.base.comp.zcu.?;
         const ip = &zcu.intern_pool;
         const nav = ip.getNav(nav_index);
-        if (nav.toExtern(ip)) |ext| {
-            const name = wasm.getExistingString(ext.name.toSlice(ip)).?;
-            if (wasm.function_imports.getIndex(name)) |i| return @enumFromInt(i);
-            return fromFunctionIndex(wasm, FunctionIndex.fromSymbolName(wasm, name).?);
-        } else {
-            return fromFunctionIndex(wasm, FunctionIndex.fromIpNav(wasm, nav_index).?);
-        }
+        return fromIpIndex(wasm, nav.status.resolved.val);
     }
 
     pub fn fromTagNameType(wasm: *const Wasm, tag_type: InternPool.Index) OutputFunctionIndex {
@@ -894,11 +910,11 @@ pub const FunctionImport = extern struct {
         pub fn fromIpNav(wasm: *const Wasm, nav_index: InternPool.Nav.Index) Resolution {
             const zcu = wasm.base.comp.zcu.?;
             const ip = &zcu.intern_pool;
-            const nav = ip.getNav(nav_index);
-            //log.debug("Resolution.fromIpNav {}({})", .{ nav.fqn.fmt(ip), nav_index });
-            return pack(wasm, .{
-                .zcu_func = @enumFromInt(wasm.zcu_funcs.getIndex(nav.status.resolved.val).?),
-            });
+            return fromIpIndex(wasm, ip.getNav(nav_index).status.resolved.val);
+        }
+
+        pub fn fromIpIndex(wasm: *const Wasm, ip_index: InternPool.Index) Resolution {
+            return pack(wasm, .{ .zcu_func = @enumFromInt(wasm.zcu_funcs.getIndex(ip_index).?) });
         }
 
         pub fn isNavOrUnresolved(r: Resolution, wasm: *const Wasm) bool {
@@ -1168,7 +1184,7 @@ pub const TableImport = extern struct {
         pub fn refType(r: Resolution, wasm: *const Wasm) std.wasm.RefType {
             return switch (unpack(r)) {
                 .unresolved => unreachable,
-                .__indirect_function_table => @panic("TODO"),
+                .__indirect_function_table => .funcref,
                 .object_table => |i| i.ptr(wasm).flags.ref_type.to(),
             };
         }
@@ -1176,7 +1192,11 @@ pub const TableImport = extern struct {
         pub fn limits(r: Resolution, wasm: *const Wasm) std.wasm.Limits {
             return switch (unpack(r)) {
                 .unresolved => unreachable,
-                .__indirect_function_table => @panic("TODO"),
+                .__indirect_function_table => .{
+                    .flags = .{ .has_max = true, .is_shared = false },
+                    .min = @intCast(wasm.indirect_function_table.entries.len + 1),
+                    .max = @intCast(wasm.indirect_function_table.entries.len + 1),
+                },
                 .object_table => |i| i.ptr(wasm).limits(),
             };
         }
@@ -2370,10 +2390,12 @@ pub fn deinit(wasm: *Wasm) void {
     wasm.global_exports.deinit(gpa);
     wasm.global_imports.deinit(gpa);
     wasm.table_imports.deinit(gpa);
+    wasm.tables.deinit(gpa);
     wasm.symbol_table.deinit(gpa);
     wasm.out_relocs.deinit(gpa);
     wasm.uav_fixups.deinit(gpa);
     wasm.nav_fixups.deinit(gpa);
+    wasm.indirect_function_table.deinit(gpa);
 
     wasm.string_bytes.deinit(gpa);
     wasm.string_table.deinit(gpa);