Commit 0fd90749d1

Veikka Tuominen <git@vexu.eu>
2022-08-07 18:55:31
stage2: generate call arguments in separate blocks
1 parent 85a3f9b
src/AstGen.zig
@@ -2500,7 +2500,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .closure_get,
             .array_base_ptr,
             .field_base_ptr,
-            .param_type,
             .ret_ptr,
             .ret_type,
             .@"try",
@@ -8228,6 +8227,33 @@ fn callExpr(
     assert(callee != .none);
     assert(node != 0);
 
+    const call_index = @intCast(Zir.Inst.Index, astgen.instructions.len);
+    const call_inst = Zir.indexToRef(call_index);
+    try gz.astgen.instructions.append(astgen.gpa, undefined);
+    try gz.instructions.append(astgen.gpa, call_index);
+
+    const scratch_top = astgen.scratch.items.len;
+    defer astgen.scratch.items.len = scratch_top;
+
+    var scratch_index = scratch_top;
+    try astgen.scratch.resize(astgen.gpa, scratch_top + call.ast.params.len);
+
+    for (call.ast.params) |param_node| {
+        var arg_block = gz.makeSubBlock(scope);
+        defer arg_block.unstack();
+
+        // `call_inst` is reused to provide the param type.
+        const arg_ref = try expr(&arg_block, &arg_block.base, .{ .coerced_ty = call_inst }, param_node);
+        _ = try arg_block.addBreak(.break_inline, call_index, arg_ref);
+
+        const body = arg_block.instructionsSlice();
+        try astgen.scratch.ensureUnusedCapacity(astgen.gpa, countBodyLenAfterFixups(astgen, body));
+        appendBodyWithFixupsArrayList(astgen, &astgen.scratch, body);
+
+        astgen.scratch.items[scratch_index] = @intCast(u32, astgen.scratch.items.len - scratch_top);
+        scratch_index += 1;
+    }
+
     const payload_index = try addExtra(astgen, Zir.Inst.Call{
         .callee = callee,
         .flags = .{
@@ -8235,22 +8261,16 @@ fn callExpr(
             .args_len = @intCast(Zir.Inst.Call.Flags.PackedArgsLen, call.ast.params.len),
         },
     });
-    var extra_index = try reserveExtra(astgen, call.ast.params.len);
-
-    for (call.ast.params) |param_node, i| {
-        const param_type = try gz.add(.{
-            .tag = .param_type,
-            .data = .{ .param_type = .{
-                .callee = callee,
-                .param_index = @intCast(u32, i),
-            } },
-        });
-        const arg_ref = try expr(gz, scope, .{ .coerced_ty = param_type }, param_node);
-        astgen.extra.items[extra_index] = @enumToInt(arg_ref);
-        extra_index += 1;
+    if (call.ast.params.len != 0) {
+        try astgen.extra.appendSlice(astgen.gpa, astgen.scratch.items[scratch_top..]);
     }
-
-    const call_inst = try gz.addPlNodePayloadIndex(.call, node, payload_index);
+    gz.astgen.instructions.set(call_index, .{
+        .tag = .call,
+        .data = .{ .pl_node = .{
+            .src_node = gz.nodeIndexToRelative(node),
+            .payload_index = payload_index,
+        } },
+    });
     return rvalue(gz, rl, call_inst, node); // TODO function call with result location
 }
 
src/print_zir.zig
@@ -246,7 +246,6 @@ const Writer = struct {
 
             .validate_array_init_ty => try self.writeValidateArrayInitTy(stream, inst),
             .array_type_sentinel => try self.writeArrayTypeSentinel(stream, inst),
-            .param_type => try self.writeParamType(stream, inst),
             .ptr_type => try self.writePtrType(stream, inst),
             .int => try self.writeInt(stream, inst),
             .int_big => try self.writeIntBig(stream, inst),
@@ -605,16 +604,6 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
-    fn writeParamType(
-        self: *Writer,
-        stream: anytype,
-        inst: Zir.Inst.Index,
-    ) (@TypeOf(stream).Error || error{OutOfMemory})!void {
-        const inst_data = self.code.instructions.items(.data)[inst].param_type;
-        try self.writeInstRef(stream, inst_data.callee);
-        try stream.print(", {d})", .{inst_data.param_index});
-    }
-
     fn writePtrType(
         self: *Writer,
         stream: anytype,
@@ -1158,7 +1147,8 @@ const Writer = struct {
     fn writeCall(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.Call, inst_data.payload_index);
-        const args = self.code.refSlice(extra.end, extra.data.flags.args_len);
+        const args_len = extra.data.flags.args_len;
+        const body = self.code.extra[extra.end..];
 
         if (extra.data.flags.ensure_result_used) {
             try stream.writeAll("nodiscard ");
@@ -1166,10 +1156,27 @@ const Writer = struct {
         try stream.print(".{s}, ", .{@tagName(@intToEnum(std.builtin.CallOptions.Modifier, extra.data.flags.packed_modifier))});
         try self.writeInstRef(stream, extra.data.callee);
         try stream.writeAll(", [");
-        for (args) |arg, i| {
-            if (i != 0) try stream.writeAll(", ");
-            try self.writeInstRef(stream, arg);
+
+        self.indent += 2;
+        if (args_len != 0) {
+            try stream.writeAll("\n");
+        }
+        var i: usize = 0;
+        var arg_start: u32 = args_len;
+        while (i < args_len) : (i += 1) {
+            try stream.writeByteNTimes(' ', self.indent);
+            const arg_end = self.code.extra[extra.end + i];
+            defer arg_start = arg_end;
+            const arg_body = body[arg_start..arg_end];
+            try self.writeBracedBody(stream, arg_body);
+
+            try stream.writeAll(",\n");
         }
+        self.indent -= 2;
+        if (args_len != 0) {
+            try stream.writeByteNTimes(' ', self.indent);
+        }
+
         try stream.writeAll("]) ");
         try self.writeSrc(stream, inst_data.src());
     }
src/Sema.zig
@@ -772,7 +772,6 @@ fn analyzeBodyInner(
             .optional_payload_unsafe      => try sema.zirOptionalPayload(block, inst, false),
             .optional_payload_unsafe_ptr  => try sema.zirOptionalPayloadPtr(block, inst, false),
             .optional_type                => try sema.zirOptionalType(block, inst),
-            .param_type                   => try sema.zirParamType(block, inst),
             .ptr_type                     => try sema.zirPtrType(block, inst),
             .overflow_arithmetic_ptr      => try sema.zirOverflowArithmeticPtr(block, inst),
             .ref                          => try sema.zirRef(block, inst),
@@ -4441,43 +4440,6 @@ fn zirStoreNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!v
     return sema.storePtr2(block, src, ptr, src, operand, src, if (is_ret) .ret_ptr else .store);
 }
 
-fn zirParamType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
-    const callee_src = sema.src;
-
-    const inst_data = sema.code.instructions.items(.data)[inst].param_type;
-    const callee = try sema.resolveInst(inst_data.callee);
-    const callee_ty = sema.typeOf(callee);
-    var param_index = inst_data.param_index;
-
-    const fn_ty = if (callee_ty.tag() == .bound_fn) fn_ty: {
-        const bound_fn_val = try sema.resolveConstValue(block, .unneeded, callee, undefined);
-        const bound_fn = bound_fn_val.castTag(.bound_fn).?.data;
-        const fn_ty = sema.typeOf(bound_fn.func_inst);
-        param_index += 1;
-        break :fn_ty fn_ty;
-    } else callee_ty;
-
-    const fn_info = if (fn_ty.zigTypeTag() == .Pointer)
-        fn_ty.childType().fnInfo()
-    else
-        fn_ty.fnInfo();
-
-    if (param_index >= fn_info.param_types.len) {
-        if (fn_info.is_var_args) {
-            return sema.addType(Type.initTag(.var_args_param));
-        }
-        // TODO implement begin_call/end_call Zir instructions and check
-        // argument count before casting arguments to parameter types.
-        return sema.fail(block, callee_src, "wrong number of arguments", .{});
-    }
-
-    if (fn_info.param_types[param_index].tag() == .generic_poison) {
-        return sema.addType(Type.initTag(.var_args_param));
-    }
-
-    return sema.addType(fn_info.param_types[param_index]);
-}
-
 fn zirStr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
@@ -5459,13 +5421,14 @@ fn zirCall(
     const func_src: LazySrcLoc = .{ .node_offset_call_func = inst_data.src_node };
     const call_src = inst_data.src();
     const extra = sema.code.extraData(Zir.Inst.Call, inst_data.payload_index);
-    const args = sema.code.refSlice(extra.end, extra.data.flags.args_len);
+    const args_len = extra.data.flags.args_len;
 
     const modifier = @intToEnum(std.builtin.CallOptions.Modifier, extra.data.flags.packed_modifier);
     const ensure_result_used = extra.data.flags.ensure_result_used;
 
     var func = try sema.resolveInst(extra.data.callee);
     var resolved_args: []Air.Inst.Ref = undefined;
+    var arg_index: u32 = 0;
 
     const func_type = sema.typeOf(func);
 
@@ -5476,16 +5439,99 @@ fn zirCall(
         const bound_func = try sema.resolveValue(block, .unneeded, func, undefined);
         const bound_data = &bound_func.cast(Value.Payload.BoundFn).?.data;
         func = bound_data.func_inst;
-        resolved_args = try sema.arena.alloc(Air.Inst.Ref, args.len + 1);
-        resolved_args[0] = bound_data.arg0_inst;
-        for (args) |zir_arg, i| {
-            resolved_args[i + 1] = try sema.resolveInst(zir_arg);
-        }
+        resolved_args = try sema.arena.alloc(Air.Inst.Ref, args_len + 1);
+        resolved_args[arg_index] = bound_data.arg0_inst;
+        arg_index += 1;
     } else {
-        resolved_args = try sema.arena.alloc(Air.Inst.Ref, args.len);
-        for (args) |zir_arg, i| {
-            resolved_args[i] = try sema.resolveInst(zir_arg);
+        resolved_args = try sema.arena.alloc(Air.Inst.Ref, args_len);
+    }
+    const total_args = args_len + @boolToInt(bound_arg_src != null);
+
+    const callee_ty = sema.typeOf(func);
+    const func_ty = func_ty: {
+        switch (callee_ty.zigTypeTag()) {
+            .Fn => break :func_ty callee_ty,
+            .Pointer => {
+                const ptr_info = callee_ty.ptrInfo().data;
+                if (ptr_info.size == .One and ptr_info.pointee_type.zigTypeTag() == .Fn) {
+                    break :func_ty ptr_info.pointee_type;
+                }
+            },
+            else => {},
+        }
+        return sema.fail(block, func_src, "type '{}' not a function", .{callee_ty.fmt(sema.mod)});
+    };
+    const func_ty_info = func_ty.fnInfo();
+
+    const fn_params_len = func_ty_info.param_types.len;
+    if (func_ty_info.is_var_args) {
+        assert(func_ty_info.cc == .C);
+        if (total_args < fn_params_len) {
+            // TODO add error note: declared here
+            if (bound_arg_src != null) {
+                return sema.fail(
+                    block,
+                    call_src,
+                    "member function expected at least {d} argument(s), found {d}",
+                    .{ fn_params_len - 1, args_len },
+                );
+            }
+            return sema.fail(
+                block,
+                func_src,
+                "expected at least {d} argument(s), found {d}",
+                .{ fn_params_len, args_len },
+            );
+        }
+    } else if (fn_params_len != total_args) {
+        // TODO add error note: declared here
+        if (bound_arg_src != null) {
+            return sema.fail(
+                block,
+                call_src,
+                "member function expected {d} argument(s), found {d}",
+                .{ fn_params_len - 1, args_len },
+            );
+        }
+        return sema.fail(
+            block,
+            call_src,
+            "expected {d} argument(s), found {d}",
+            .{ fn_params_len, args_len },
+        );
+    }
+
+    const args_body = sema.code.extra[extra.end..];
+
+    const parent_comptime = block.is_comptime;
+    // `extra_index` and `arg_index` are separate since the bound function is passed as the first argument.
+    var extra_index: usize = 0;
+    var arg_start: u32 = args_len;
+    while (extra_index < args_len) : ({
+        extra_index += 1;
+        arg_index += 1;
+    }) {
+        const arg_end = sema.code.extra[extra.end + extra_index];
+        defer arg_start = arg_end;
+
+        const param_ty = if (arg_index >= fn_params_len or
+            func_ty_info.param_types[arg_index].tag() == .generic_poison)
+            Type.initTag(.var_args_param)
+        else
+            func_ty_info.param_types[arg_index];
+
+        const old_comptime = block.is_comptime;
+        defer block.is_comptime = old_comptime;
+        // Generate args to comptime params in comptime block.
+        block.is_comptime = parent_comptime;
+        if (arg_index < fn_params_len and func_ty_info.comptime_params[arg_index]) {
+            block.is_comptime = true;
         }
+
+        const param_ty_inst = try sema.addType(param_ty);
+        try sema.inst_map.put(sema.gpa, inst, param_ty_inst);
+
+        resolved_args[arg_index] = try sema.resolveBody(block, args_body[arg_start..arg_end], inst);
     }
 
     return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src);
@@ -5964,7 +6010,18 @@ fn analyzeCall(
                     else => |e| return e,
                 };
             } else {
-                args[i] = uncasted_arg;
+                args[i] = sema.coerceVarArgParam(block, uncasted_arg, .unneeded) catch |err| switch (err) {
+                    error.NeededSourceLocation => {
+                        const decl = sema.mod.declPtr(block.src_decl);
+                        _ = try sema.coerceVarArgParam(
+                            block,
+                            uncasted_arg,
+                            Module.argSrc(call_src.node_offset.x, sema.gpa, decl, i, bound_arg_src),
+                        );
+                        return error.AnalysisFail;
+                    },
+                    else => |e| return e,
+                };
             }
         }
 
@@ -23534,7 +23591,10 @@ fn coerceVarArgParam(
     inst_src: LazySrcLoc,
 ) !Air.Inst.Ref {
     const inst_ty = sema.typeOf(inst);
+    if (block.is_typeof) return inst;
+
     switch (inst_ty.zigTypeTag()) {
+        // TODO consider casting to c_int/f64 if they fit
         .ComptimeInt, .ComptimeFloat => return sema.fail(block, inst_src, "integer and float literals in var args function must be casted", .{}),
         else => {},
     }
src/Zir.zig
@@ -490,14 +490,6 @@ pub const Inst = struct {
         /// Merge two error sets into one, `E1 || E2`.
         /// Uses the `pl_node` field with payload `Bin`.
         merge_error_sets,
-        /// Given a reference to a function and a parameter index, returns the
-        /// type of the parameter. The only usage of this instruction is for the
-        /// result location of parameters of function calls. In the case of a function's
-        /// parameter type being `anytype`, it is the type coercion's job to detect this
-        /// scenario and skip the coercion, so that semantic analysis of this instruction
-        /// is not in a position where it must create an invalid type.
-        /// Uses the `param_type` union field.
-        param_type,
         /// Turns an R-Value into a const L-Value. In other words, it takes a value,
         /// stores it in a memory location, and returns a const pointer to it. If the value
         /// is `comptime`, the memory location is global static constant data. Otherwise,
@@ -1095,7 +1087,6 @@ pub const Inst = struct {
                 .mul,
                 .mulwrap,
                 .mul_sat,
-                .param_type,
                 .ref,
                 .shl,
                 .shl_sat,
@@ -1397,7 +1388,6 @@ pub const Inst = struct {
                 .mul,
                 .mulwrap,
                 .mul_sat,
-                .param_type,
                 .ref,
                 .shl,
                 .shl_sat,
@@ -1569,7 +1559,6 @@ pub const Inst = struct {
                 .mulwrap = .pl_node,
                 .mul_sat = .pl_node,
 
-                .param_type = .param_type,
                 .param = .pl_tok,
                 .param_comptime = .pl_tok,
                 .param_anytype = .str_tok,
@@ -2540,10 +2529,6 @@ pub const Inst = struct {
             /// Points to a `Block`.
             payload_index: u32,
         },
-        param_type: struct {
-            callee: Ref,
-            param_index: u32,
-        },
         @"unreachable": struct {
             /// Offset from Decl AST node index.
             /// `Tag` determines which kind of AST node this points to.
@@ -2614,7 +2599,6 @@ pub const Inst = struct {
             ptr_type,
             int_type,
             bool_br,
-            param_type,
             @"unreachable",
             @"break",
             switch_capture,
@@ -2794,7 +2778,9 @@ pub const Inst = struct {
     };
 
     /// Stored inside extra, with trailing arguments according to `args_len`.
-    /// Each argument is a `Ref`.
+    /// Implicit 0. arg_0_start: u32, // always same as `args_len`
+    /// 1. arg_end: u32, // for each `args_len`
+    /// arg_N_start is the same as arg_N-1_end
     pub const Call = struct {
         // Note: Flags *must* come first so that unusedResultExpr
         // can find it when it goes to modify them.
test/behavior/call.zig
@@ -246,3 +246,18 @@ test "function call with 40 arguments" {
     };
     try S.doTheTest(39);
 }
+
+test "arguments to comptime parameters generated in comptime blocks" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+
+    const S = struct {
+        fn fortyTwo() i32 {
+            return 42;
+        }
+
+        fn foo(comptime x: i32) void {
+            if (x != 42) @compileError("bad");
+        }
+    };
+    S.foo(S.fortyTwo());
+}
test/cases/compile_errors/int_literal_passed_as_variadic_arg.zig
@@ -0,0 +1,11 @@
+extern fn printf([*:0]const u8, ...) c_int;
+
+pub export fn entry() void {
+    _ = printf("%d %d %d %d\n", 1, 2, 3, 4);
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :4:33: error: integer and float literals in var args function must be casted
test/cases/compile_errors/member_function_arg_mismatch.zig
@@ -0,0 +1,14 @@
+const S = struct {
+    a: u32,
+    fn foo(_: *S, _: u32, _: bool) void {}
+};
+pub export fn entry() void {
+    var s: S = undefined;
+    s.foo(true);
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :7:10: error: member function expected 2 argument(s), found 1