Commit 67a44211f7

Veikka Tuominen <git@vexu.eu>
2022-08-30 13:04:13
Sema: improve handling of always_tail call modifier
Closes #4301 Closes #5692 Closes #6281 Closes #10786 Closes #11149 Closes #11776
1 parent 0a42602
Changed files (5)
src
test
src/codegen/llvm.zig
@@ -4522,7 +4522,7 @@ pub const FuncGen = struct {
             "",
         );
 
-        if (return_type.isNoReturn()) {
+        if (return_type.isNoReturn() and attr != .AlwaysTail) {
             _ = self.builder.buildUnreachable();
             return null;
         }
src/Sema.zig
@@ -6152,9 +6152,23 @@ fn analyzeCall(
     if (ensure_result_used) {
         try sema.ensureResultUsed(block, result, call_src);
     }
+    if (call_tag == .call_always_tail) {
+        return sema.handleTailCall(block, call_src, func_ty, result);
+    }
     return result;
 }
 
+fn handleTailCall(sema: *Sema, block: *Block, call_src: LazySrcLoc, func_ty: Type, result: Air.Inst.Ref) !Air.Inst.Ref {
+    const func_decl = sema.mod.declPtr(sema.owner_func.?.owner_decl);
+    if (!func_ty.eql(func_decl.ty, sema.mod)) {
+        return sema.fail(block, call_src, "unable to perform tail call: type of function being called '{}' does not match type of calling function '{}'", .{
+            func_ty.fmt(sema.mod), func_decl.ty.fmt(sema.mod),
+        });
+    }
+    _ = try block.addUnOp(.ret, result);
+    return Air.Inst.Ref.unreachable_value;
+}
+
 fn analyzeInlineCallArg(
     sema: *Sema,
     arg_block: *Block,
@@ -6670,7 +6684,8 @@ fn instantiateGenericCall(
     try sema.requireFunctionBlock(block, call_src);
 
     const comptime_args = callee.comptime_args.?;
-    const new_fn_info = mod.declPtr(callee.owner_decl).ty.fnInfo();
+    const func_ty = mod.declPtr(callee.owner_decl).ty;
+    const new_fn_info = func_ty.fnInfo();
     const runtime_args_len = @intCast(u32, new_fn_info.param_types.len);
     const runtime_args = try sema.arena.alloc(Air.Inst.Ref, runtime_args_len);
     {
@@ -6717,7 +6732,7 @@ fn instantiateGenericCall(
 
     try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.Call).Struct.fields.len +
         runtime_args_len);
-    const func_inst = try block.addInst(.{
+    const result = try block.addInst(.{
         .tag = call_tag,
         .data = .{ .pl_op = .{
             .operand = callee_inst,
@@ -6729,9 +6744,12 @@ fn instantiateGenericCall(
     sema.appendRefsAssumeCapacity(runtime_args);
 
     if (ensure_result_used) {
-        try sema.ensureResultUsed(block, func_inst, call_src);
+        try sema.ensureResultUsed(block, result, call_src);
     }
-    return func_inst;
+    if (call_tag == .call_always_tail) {
+        return sema.handleTailCall(block, call_src, func_ty, result);
+    }
+    return result;
 }
 
 fn emitDbgInline(
@@ -19262,7 +19280,7 @@ fn resolveCallOptions(
             return wanted_modifier;
         },
         // These can be upgraded to comptime. nosuspend bit can be safely ignored.
-        .always_tail, .always_inline, .compile_time => {
+        .always_inline, .compile_time => {
             _ = (try sema.resolveDefinedValue(block, func_src, func)) orelse {
                 return sema.fail(block, func_src, "modifier '{s}' requires a comptime-known function", .{@tagName(wanted_modifier)});
             };
@@ -19272,6 +19290,12 @@ fn resolveCallOptions(
             }
             return wanted_modifier;
         },
+        .always_tail => {
+            if (is_comptime) {
+                return .compile_time;
+            }
+            return wanted_modifier;
+        },
         .async_kw => {
             if (is_nosuspend) {
                 return sema.fail(block, modifier_src, "modifier 'async_kw' cannot be used inside nosuspend block", .{});
test/behavior/call.zig
@@ -261,3 +261,57 @@ test "arguments to comptime parameters generated in comptime blocks" {
     };
     S.foo(S.fortyTwo());
 }
+
+test "forced tail call" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn fibonacciTailInternal(n: u16, a: u16, b: u16) u16 {
+            if (n == 0) return a;
+            if (n == 1) return b;
+            return @call(
+                .{ .modifier = .always_tail },
+                fibonacciTailInternal,
+                .{ n - 1, b, a + b },
+            );
+        }
+
+        fn fibonacciTail(n: u16) u16 {
+            return fibonacciTailInternal(n, 0, 1);
+        }
+    };
+    try expect(S.fibonacciTail(10) == 55);
+}
+
+test "inline call preserves tail call" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const max = std.math.maxInt(u16);
+    const S = struct {
+        var a: u16 = 0;
+        fn foo() void {
+            return bar();
+        }
+
+        inline fn bar() void {
+            if (a == max) return;
+            // Stack overflow if not tail called
+            var buf: [max]u16 = undefined;
+            buf[a] = a;
+            a += 1;
+            return @call(.{ .modifier = .always_tail }, foo, .{});
+        }
+    };
+    S.foo();
+    try expect(S.a == std.math.maxInt(u16));
+}
test/cases/compile_errors/invalid_tail_call.zig
@@ -0,0 +1,12 @@
+fn myFn(_: usize) void {
+    return;
+}
+pub export fn entry() void {
+    @call(.{ .modifier = .always_tail }, myFn, .{0});
+}
+
+// error
+// backend=llvm
+// target=native
+//
+// :5:5: error: unable to perform tail call: type of function being called 'fn(usize) void' does not match type of calling function 'fn() callconv(.C) void'
test/cases/taill_call_noreturn.zig
@@ -0,0 +1,18 @@
+const std = @import("std");
+const builtin = std.builtin;
+pub fn foo(message: []const u8, stack_trace: ?*builtin.StackTrace) noreturn {
+    @call(.{ .modifier = .always_tail }, bar, .{ message, stack_trace });
+}
+pub fn bar(message: []const u8, stack_trace: ?*builtin.StackTrace) noreturn {
+    _ = message;
+    _ = stack_trace;
+    std.process.exit(0);
+}
+
+pub fn main() void {
+    foo("foo", null);
+}
+
+// run
+// backend=llvm
+// target=native