Commit f78f9388fe

mlugg <mlugg@mlugg.co.uk>
2025-01-12 21:55:30
Sema: allow tail calls of function pointers
Resolves: #22474
1 parent 15fe999
Changed files (4)
src/Sema.zig
@@ -7983,7 +7983,13 @@ fn analyzeCall(
         }
 
         if (call_tag == .call_always_tail) {
-            return sema.handleTailCall(block, call_src, sema.typeOf(runtime_func), result);
+            const func_or_ptr_ty = sema.typeOf(runtime_func);
+            const runtime_func_ty = switch (func_or_ptr_ty.zigTypeTag(zcu)) {
+                .@"fn" => func_or_ptr_ty,
+                .pointer => func_or_ptr_ty.childType(zcu),
+                else => unreachable,
+            };
+            return sema.handleTailCall(block, call_src, runtime_func_ty, result);
         }
 
         if (resolved_ret_ty.toIntern() == .noreturn_type) {
test/behavior/call.zig
@@ -651,3 +651,106 @@ test "function call with cast to anyopaque pointer" {
     };
     Foo.bar(Foo.t);
 }
+
+test "arguments pointed to on stack into tailcall" {
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c and builtin.os.tag == .windows) return error.SkipZigTest; // MSVC doesn't support always tail calls
+
+    switch (builtin.cpu.arch) {
+        .wasm32,
+        .mips,
+        .mipsel,
+        .mips64,
+        .mips64el,
+        .powerpc,
+        .powerpcle,
+        .powerpc64,
+        .powerpc64le,
+        => return error.SkipZigTest,
+        else => {},
+    }
+
+    const S = struct {
+        var base: usize = undefined;
+        var result_off: [7]usize = undefined;
+        var result_len: [7]usize = undefined;
+        var result_index: usize = 0;
+
+        noinline fn insertionSort(data: []u64) void {
+            result_off[result_index] = @intFromPtr(data.ptr) - base;
+            result_len[result_index] = data.len;
+            result_index += 1;
+            if (data.len > 1) {
+                var least_i: usize = 0;
+                var i: usize = 1;
+                while (i < data.len) : (i += 1) {
+                    if (data[i] < data[least_i])
+                        least_i = i;
+                }
+                std.mem.swap(u64, &data[0], &data[least_i]);
+
+                // there used to be a bug where
+                // `data[1..]` is created on the stack
+                // and pointed to by the first argument register
+                // then stack is invalidated by the tailcall and
+                // overwritten by callee
+                // https://github.com/ziglang/zig/issues/9703
+                return @call(.always_tail, insertionSort, .{data[1..]});
+            }
+        }
+    };
+
+    var data = [_]u64{ 1, 6, 2, 7, 1, 9, 3 };
+    S.base = @intFromPtr(&data);
+    S.insertionSort(data[0..]);
+    try expect(S.result_len[0] == 7);
+    try expect(S.result_len[1] == 6);
+    try expect(S.result_len[2] == 5);
+    try expect(S.result_len[3] == 4);
+    try expect(S.result_len[4] == 3);
+    try expect(S.result_len[5] == 2);
+    try expect(S.result_len[6] == 1);
+
+    try expect(S.result_off[0] == 0);
+    try expect(S.result_off[1] == 8);
+    try expect(S.result_off[2] == 16);
+    try expect(S.result_off[3] == 24);
+    try expect(S.result_off[4] == 32);
+    try expect(S.result_off[5] == 40);
+    try expect(S.result_off[6] == 48);
+}
+
+test "tail call function pointer" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_llvm) {
+        if (builtin.cpu.arch.isMIPS() or builtin.cpu.arch.isPowerPC() or builtin.cpu.arch.isWasm()) {
+            return error.SkipZigTest;
+        }
+    }
+
+    if (builtin.zig_backend == .stage2_c and builtin.os.tag == .windows) return error.SkipZigTest; // MSVC doesn't support always tail calls
+
+    const S = struct {
+        fn foo(n: u8) void {
+            if (n == 0) return;
+            const other: *const fn (u8) void = &bar;
+            return @call(.always_tail, other, .{n - 1});
+        }
+        fn bar(n: u8) void {
+            var other: *const fn (u8) void = undefined;
+            other = &foo; // runtime-known pointer
+            return @call(.always_tail, other, .{n});
+        }
+    };
+
+    S.foo(100);
+}
test/behavior/call_tail.zig
@@ -1,72 +0,0 @@
-const builtin = @import("builtin");
-const std = @import("std");
-const expect = std.testing.expect;
-
-var base: usize = undefined;
-var result_off: [7]usize = undefined;
-var result_len: [7]usize = undefined;
-var result_index: usize = 0;
-
-noinline fn insertionSort(data: []u64) void {
-    result_off[result_index] = @intFromPtr(data.ptr) - base;
-    result_len[result_index] = data.len;
-    result_index += 1;
-    if (data.len > 1) {
-        var least_i: usize = 0;
-        var i: usize = 1;
-        while (i < data.len) : (i += 1) {
-            if (data[i] < data[least_i])
-                least_i = i;
-        }
-        std.mem.swap(u64, &data[0], &data[least_i]);
-
-        // there used to be a bug where
-        // `data[1..]` is created on the stack
-        // and pointed to by the first argument register
-        // then stack is invalidated by the tailcall and
-        // overwritten by callee
-        // https://github.com/ziglang/zig/issues/9703
-        return @call(.always_tail, insertionSort, .{data[1..]});
-    }
-}
-
-test "arguments pointed to on stack into tailcall" {
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
-    switch (builtin.cpu.arch) {
-        .wasm32,
-        .mips,
-        .mipsel,
-        .mips64,
-        .mips64el,
-        .powerpc,
-        .powerpcle,
-        .powerpc64,
-        .powerpc64le,
-        => return error.SkipZigTest,
-        else => {},
-    }
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
-    if (builtin.zig_backend == .stage2_c and builtin.os.tag == .windows) return error.SkipZigTest; // MSVC doesn't support always tail calls
-
-    var data = [_]u64{ 1, 6, 2, 7, 1, 9, 3 };
-    base = @intFromPtr(&data);
-    insertionSort(data[0..]);
-    try expect(result_len[0] == 7);
-    try expect(result_len[1] == 6);
-    try expect(result_len[2] == 5);
-    try expect(result_len[3] == 4);
-    try expect(result_len[4] == 3);
-    try expect(result_len[5] == 2);
-    try expect(result_len[6] == 1);
-
-    try expect(result_off[0] == 0);
-    try expect(result_off[1] == 8);
-    try expect(result_off[2] == 16);
-    try expect(result_off[3] == 24);
-    try expect(result_off[4] == 32);
-    try expect(result_off[5] == 40);
-    try expect(result_off[6] == 48);
-}
test/behavior.zig
@@ -17,7 +17,6 @@ test {
     _ = @import("behavior/byteswap.zig");
     _ = @import("behavior/byval_arg_var.zig");
     _ = @import("behavior/call.zig");
-    _ = @import("behavior/call_tail.zig");
     _ = @import("behavior/cast.zig");
     _ = @import("behavior/cast_int.zig");
     _ = @import("behavior/comptime_memory.zig");