Commit 67afd2a470

Veikka Tuominen <git@vexu.eu>
2023-05-10 11:27:59
Sema: make `@call` compile errors match regular calls
Closes #15642
1 parent 73f283e
Changed files (2)
src
test
src/Sema.zig
@@ -5217,7 +5217,7 @@ fn zirPanic(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir.I
     if (block.is_comptime) {
         return sema.fail(block, src, "encountered @panic at comptime", .{});
     }
-    try sema.panicWithMsg(block, src, msg_inst);
+    try sema.panicWithMsg(block, msg_inst);
     return always_noreturn;
 }
 
@@ -6295,7 +6295,6 @@ fn zirCall(
     } else {
         resolved_args = try sema.arena.alloc(Air.Inst.Ref, args_len);
     }
-    const total_args = args_len + @boolToInt(bound_arg_src != null);
 
     const callee_ty = sema.typeOf(func);
     const func_ty = func_ty: {
@@ -6311,45 +6310,16 @@ fn zirCall(
         }
         return sema.fail(block, func_src, "type '{}' not a function", .{callee_ty.fmt(sema.mod)});
     };
-    const func_ty_info = func_ty.fnInfo();
-
-    const fn_params_len = func_ty_info.param_types.len;
-    check_args: {
-        if (func_ty_info.is_var_args) {
-            assert(func_ty_info.cc == .C);
-            if (total_args >= fn_params_len) break :check_args;
-        } else if (fn_params_len == total_args) {
-            break :check_args;
-        }
-
-        const maybe_decl = try sema.funcDeclSrc(func);
-        const member_str = if (bound_arg_src != null) "member function " else "";
-        const variadic_str = if (func_ty_info.is_var_args) "at least " else "";
-        const msg = msg: {
-            const msg = try sema.errMsg(
-                block,
-                func_src,
-                "{s}expected {s}{d} argument(s), found {d}",
-                .{
-                    member_str,
-                    variadic_str,
-                    fn_params_len - @boolToInt(bound_arg_src != null),
-                    args_len,
-                },
-            );
-            errdefer msg.destroy(sema.gpa);
-
-            if (maybe_decl) |fn_decl| try sema.mod.errNoteNonLazy(fn_decl.srcLoc(), msg, "function declared here", .{});
-            break :msg msg;
-        };
-        return sema.failWithOwnedErrorMsg(msg);
-    }
+    const total_args = args_len + @boolToInt(bound_arg_src != null);
+    try sema.checkCallArgumentCount(block, func, func_src, func_ty, total_args, bound_arg_src != null);
 
     const args_body = sema.code.extra[extra.end..];
 
     var input_is_error = false;
     const block_index = @intCast(Air.Inst.Index, block.instructions.items.len);
 
+    const func_ty_info = func_ty.fnInfo();
+    const fn_params_len = func_ty_info.param_types.len;
     const parent_comptime = block.is_comptime;
     // `extra_index` and `arg_index` are separate since the bound function is passed as the first argument.
     var extra_index: usize = 0;
@@ -6398,7 +6368,7 @@ fn zirCall(
         !block.is_comptime and !block.is_typeof and (input_is_error or pop_error_return_trace))
     {
         const call_inst: Air.Inst.Ref = if (modifier == .always_tail) undefined else b: {
-            break :b try sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
+            break :b try sema.analyzeCall(block, func, func_ty, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
         };
 
         const return_ty = sema.typeOf(call_inst);
@@ -6427,12 +6397,84 @@ fn zirCall(
         }
 
         if (modifier == .always_tail) // Perform the call *after* the restore, so that a tail call is possible.
-            return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
+            return sema.analyzeCall(block, func, func_ty, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
 
         return call_inst;
     } else {
-        return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
+        return sema.analyzeCall(block, func, func_ty, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, call_dbg_node);
+    }
+}
+
+fn checkCallArgumentCount(
+    sema: *Sema,
+    block: *Block,
+    func: Air.Inst.Ref,
+    func_src: LazySrcLoc,
+    func_ty: Type,
+    total_args: usize,
+    member_fn: bool,
+) !void {
+    const func_ty_info = func_ty.fnInfo();
+    const fn_params_len = func_ty_info.param_types.len;
+    const args_len = total_args - @boolToInt(member_fn);
+    if (func_ty_info.is_var_args) {
+        assert(func_ty_info.cc == .C);
+        if (total_args >= fn_params_len) return;
+    } else if (fn_params_len == total_args) {
+        return;
     }
+
+    const maybe_decl = try sema.funcDeclSrc(func);
+    const member_str = if (member_fn) "member function " else "";
+    const variadic_str = if (func_ty_info.is_var_args) "at least " else "";
+    const msg = msg: {
+        const msg = try sema.errMsg(
+            block,
+            func_src,
+            "{s}expected {s}{d} argument(s), found {d}",
+            .{
+                member_str,
+                variadic_str,
+                fn_params_len - @boolToInt(member_fn),
+                args_len,
+            },
+        );
+        errdefer msg.destroy(sema.gpa);
+
+        if (maybe_decl) |fn_decl| try sema.mod.errNoteNonLazy(fn_decl.srcLoc(), msg, "function declared here", .{});
+        break :msg msg;
+    };
+    return sema.failWithOwnedErrorMsg(msg);
+}
+
+fn callBuiltin(
+    sema: *Sema,
+    block: *Block,
+    builtin_fn: Air.Inst.Ref,
+    modifier: std.builtin.CallModifier,
+    args: []const Air.Inst.Ref,
+) !void {
+    const callee_ty = sema.typeOf(builtin_fn);
+    const func_ty = func_ty: {
+        switch (callee_ty.zigTypeTag()) {
+            .Fn => break :func_ty callee_ty,
+            .Pointer => {
+                const ptr_info = callee_ty.ptrInfo().data;
+                if (ptr_info.size == .One and ptr_info.pointee_type.zigTypeTag() == .Fn) {
+                    break :func_ty ptr_info.pointee_type;
+                }
+            },
+            else => {},
+        }
+        std.debug.panic("type '{}' is not a function calling builtin fn", .{callee_ty.fmt(sema.mod)});
+    };
+
+    const func_ty_info = func_ty.fnInfo();
+    const fn_params_len = func_ty_info.param_types.len;
+    if (args.len != fn_params_len or (func_ty_info.is_var_args and args.len < fn_params_len)) {
+        std.debug.panic("parameter count mismatch calling builtin fn, expected {d}, found {d}", .{ fn_params_len, args.len });
+    }
+    _ = try sema.analyzeCall(block, builtin_fn, func_ty, sema.src, sema.src, modifier, false, args, null, null);
 }
 
 const GenericCallAdapter = struct {
@@ -6509,6 +6551,7 @@ fn analyzeCall(
     sema: *Sema,
     block: *Block,
     func: Air.Inst.Ref,
+    func_ty: Type,
     func_src: LazySrcLoc,
     call_src: LazySrcLoc,
     modifier: std.builtin.CallModifier,
@@ -6519,22 +6562,10 @@ fn analyzeCall(
 ) CompileError!Air.Inst.Ref {
     const mod = sema.mod;
 
-    const callee_ty = sema.typeOf(func);
-    const func_ty = func_ty: {
-        switch (callee_ty.zigTypeTag()) {
-            .Fn => break :func_ty callee_ty,
-            .Pointer => {
-                const ptr_info = callee_ty.ptrInfo().data;
-                if (ptr_info.size == .One and ptr_info.pointee_type.zigTypeTag() == .Fn) {
-                    break :func_ty ptr_info.pointee_type;
-                }
-            },
-            else => {},
-        }
-        return sema.fail(block, func_src, "type '{}' is not a function", .{callee_ty.fmt(sema.mod)});
-    };
 
+    const callee_ty = sema.typeOf(func);
     const func_ty_info = func_ty.fnInfo();
+    const fn_params_len = func_ty_info.param_types.len;
     const cc = func_ty_info.cc;
     if (cc == .Naked) {
         const maybe_decl = try sema.funcDeclSrc(func);
@@ -6552,27 +6583,6 @@ fn analyzeCall(
         };
         return sema.failWithOwnedErrorMsg(msg);
     }
-    const fn_params_len = func_ty_info.param_types.len;
-    if (func_ty_info.is_var_args) {
-        assert(cc == .C);
-        if (uncasted_args.len < fn_params_len) {
-            // TODO add error note: declared here
-            return sema.fail(
-                block,
-                func_src,
-                "expected at least {d} argument(s), found {d}",
-                .{ fn_params_len, uncasted_args.len },
-            );
-        }
-    } else if (fn_params_len != uncasted_args.len) {
-        // TODO add error note: declared here
-        return sema.fail(
-            block,
-            call_src,
-            "expected {d} argument(s), found {d}",
-            .{ fn_params_len, uncasted_args.len },
-        );
-    }
 
     const call_tag: Air.Inst.Tag = switch (modifier) {
         .auto,
@@ -11822,9 +11832,6 @@ fn maybeErrorUnwrap(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, op
             .as_node => try sema.zirAsNode(block, inst),
             .field_val => try sema.zirFieldVal(block, inst),
             .@"unreachable" => {
-                const inst_data = sema.code.instructions.items(.data)[inst].@"unreachable";
-                const src = inst_data.src();
-
                 if (!sema.mod.comp.formatted_panics) {
                     try sema.safetyPanic(block, .unwrap_error);
                     return true;
@@ -11833,18 +11840,17 @@ fn maybeErrorUnwrap(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, op
                 const panic_fn = try sema.getBuiltin("panicUnwrapError");
                 const err_return_trace = try sema.getErrorReturnTrace(block);
                 const args: [2]Air.Inst.Ref = .{ err_return_trace, operand };
-                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null, null);
+                try sema.callBuiltin(block, panic_fn, .auto, &args);
                 return true;
             },
             .panic => {
                 const inst_data = sema.code.instructions.items(.data)[inst].un_node;
-                const src = inst_data.src();
                 const msg_inst = try sema.resolveInst(inst_data.operand);
 
                 const panic_fn = try sema.getBuiltin("panic");
                 const err_return_trace = try sema.getErrorReturnTrace(block);
                 const args: [3]Air.Inst.Ref = .{ msg_inst, err_return_trace, .null_value };
-                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null, null);
+                try sema.callBuiltin(block, panic_fn, .auto, &args);
                 return true;
             },
             else => unreachable,
@@ -17258,7 +17264,7 @@ fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir
 
     if (sema.wantErrorReturnTracing(sema.fn_ret_ty)) {
         const is_non_err = try sema.analyzePtrIsNonErr(block, src, ret_ptr);
-        return sema.retWithErrTracing(block, src, is_non_err, .ret_load, ret_ptr);
+        return sema.retWithErrTracing(block, is_non_err, .ret_load, ret_ptr);
     }
 
     _ = try block.addUnOp(.ret_load, ret_ptr);
@@ -17268,7 +17274,6 @@ fn zirRetLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Zir
 fn retWithErrTracing(
     sema: *Sema,
     block: *Block,
-    src: LazySrcLoc,
     is_non_err: Air.Inst.Ref,
     ret_tag: Air.Inst.Tag,
     operand: Air.Inst.Ref,
@@ -17290,7 +17295,7 @@ fn retWithErrTracing(
     const args: [1]Air.Inst.Ref = .{err_return_trace};
 
     if (!need_check) {
-        _ = try sema.analyzeCall(block, return_err_fn, src, src, .never_inline, false, &args, null, null);
+        try sema.callBuiltin(block, return_err_fn, .never_inline, &args);
         _ = try block.addUnOp(ret_tag, operand);
         return always_noreturn;
     }
@@ -17301,7 +17306,7 @@ fn retWithErrTracing(
 
     var else_block = block.makeSubBlock();
     defer else_block.instructions.deinit(gpa);
-    _ = try sema.analyzeCall(&else_block, return_err_fn, src, src, .never_inline, false, &args, null, null);
+    try sema.callBuiltin(&else_block, return_err_fn, .never_inline, &args);
     _ = try else_block.addUnOp(ret_tag, operand);
 
     try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.CondBr).Struct.fields.len +
@@ -17447,7 +17452,7 @@ fn analyzeRet(
         // Avoid adding a frame to the error return trace in case the value is comptime-known
         // to be not an error.
         const is_non_err = try sema.analyzeIsNonErr(block, src, operand);
-        return sema.retWithErrTracing(block, src, is_non_err, .ret, operand);
+        return sema.retWithErrTracing(block, is_non_err, .ret, operand);
     }
 
     _ = try block.addUnOp(.ret, operand);
@@ -21657,8 +21662,25 @@ fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             resolved.* = try sema.tupleFieldValByIndex(block, args_src, args, @intCast(u32, i), args_ty);
         }
     }
+
+    const callee_ty = sema.typeOf(func);
+    const func_ty = func_ty: {
+        switch (callee_ty.zigTypeTag()) {
+            .Fn => break :func_ty callee_ty,
+            .Pointer => {
+                const ptr_info = callee_ty.ptrInfo().data;
+                if (ptr_info.size == .One and ptr_info.pointee_type.zigTypeTag() == .Fn) {
+                    break :func_ty ptr_info.pointee_type;
+                }
+            },
+            else => {},
+        }
+        return sema.fail(block, func_src, "type '{}' not a function", .{callee_ty.fmt(sema.mod)});
+    };
+    try sema.checkCallArgumentCount(block, func, func_src, func_ty, resolved_args.len, bound_arg_src != null);
+
     const ensure_result_used = extra.flags.ensure_result_used;
-    return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, null);
+    return sema.analyzeCall(block, func, func_ty, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src, null);
 }
 
 fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -23469,7 +23491,6 @@ fn addSafetyCheckExtra(
 fn panicWithMsg(
     sema: *Sema,
     block: *Block,
-    src: LazySrcLoc,
     msg_inst: Air.Inst.Ref,
 ) !void {
     const mod = sema.mod;
@@ -23492,7 +23513,7 @@ fn panicWithMsg(
         Value.null,
     );
     const args: [3]Air.Inst.Ref = .{ msg_inst, null_stack_trace, .null_value };
-    _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null, null);
+    try sema.callBuiltin(block, panic_fn, .auto, &args);
 }
 
 fn panicUnwrapError(
@@ -23530,7 +23551,7 @@ fn panicUnwrapError(
             const err = try fail_block.addTyOp(unwrap_err_tag, Type.anyerror, operand);
             const err_return_trace = try sema.getErrorReturnTrace(&fail_block);
             const args: [2]Air.Inst.Ref = .{ err_return_trace, err };
-            _ = try sema.analyzeCall(&fail_block, panic_fn, sema.src, sema.src, .auto, false, &args, null, null);
+            try sema.callBuiltin(&fail_block, panic_fn, .auto, &args);
         }
     }
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
@@ -23615,7 +23636,7 @@ fn panicSentinelMismatch(
     else {
         const panic_fn = try sema.getBuiltin("checkNonScalarSentinel");
         const args: [2]Air.Inst.Ref = .{ expected_sentinel, actual_sentinel };
-        _ = try sema.analyzeCall(parent_block, panic_fn, sema.src, sema.src, .auto, false, &args, null, null);
+        try sema.callBuiltin(parent_block, panic_fn, .auto, &args);
         return;
     };
 
@@ -23652,7 +23673,7 @@ fn safetyCheckFormatted(
         _ = try fail_block.addNoOp(.trap);
     } else {
         const panic_fn = try sema.getBuiltin(func);
-        _ = try sema.analyzeCall(&fail_block, panic_fn, sema.src, sema.src, .auto, false, args, null, null);
+        try sema.callBuiltin(&fail_block, panic_fn, .auto, args);
     }
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
 }
@@ -23671,7 +23692,7 @@ fn safetyPanic(
     )).?;
 
     const msg_inst = try sema.analyzeDeclVal(block, sema.src, msg_decl_index);
-    try sema.panicWithMsg(block, sema.src, msg_inst);
+    try sema.panicWithMsg(block, msg_inst);
 }
 
 fn emitBackwardBranch(sema: *Sema, block: *Block, src: LazySrcLoc) !void {
test/cases/compile_errors/member_function_arg_mismatch.zig
@@ -6,6 +6,10 @@ pub export fn entry() void {
     var s: S = undefined;
     s.foo(true);
 }
+pub export fn entry2() void {
+    var s: S = undefined;
+    @call(.auto, s.foo, .{true});
+}
 
 // error
 // backend=stage2
@@ -13,3 +17,5 @@ pub export fn entry() void {
 //
 // :7:6: error: member function expected 2 argument(s), found 1
 // :3:5: note: function declared here
+// :11:19: error: member function expected 2 argument(s), found 1
+// :3:5: note: function declared here