Commit cd04b49041

Andrew Kelley <andrew@ziglang.org>
2022-05-19 02:01:45
stage2: fix `@call` when used in a comptime or nosuspend block
`@call` allows specifying the modifier explicitly, however it can still appear in a context that overrides the modifier. This commit adds flags to the BuiltinCall ZIR encoding. Since we have unused bits I also threw in the ensure_result_used mechanism. I also deleted a behavior test that was checking for bound function behavior where I think stage2 behavior is correct and stage1 behavior is incorrect.
1 parent 5626bb4
Changed files (5)
src/AstGen.zig
@@ -71,6 +71,7 @@ fn setExtra(astgen: *AstGen, index: usize, extra: anytype) void {
             Zir.Inst.Ref => @enumToInt(@field(extra, field.name)),
             i32 => @bitCast(u32, @field(extra, field.name)),
             Zir.Inst.Call.Flags => @bitCast(u32, @field(extra, field.name)),
+            Zir.Inst.BuiltinCall.Flags => @bitCast(u32, @field(extra, field.name)),
             Zir.Inst.SwitchBlock.Bits => @bitCast(u32, @field(extra, field.name)),
             Zir.Inst.ExtendedFunc.Bits => @bitCast(u32, @field(extra, field.name)),
             else => @compileError("bad field type"),
@@ -2213,6 +2214,14 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: Ast.Node.Index) Inner
                 slot.* = @bitCast(u32, flags);
                 break :b true;
             },
+            .builtin_call => {
+                const extra_index = gz.astgen.instructions.items(.data)[inst].pl_node.payload_index;
+                const slot = &gz.astgen.extra.items[extra_index];
+                var flags = @bitCast(Zir.Inst.BuiltinCall.Flags, slot.*);
+                flags.ensure_result_used = true;
+                slot.* = @bitCast(u32, flags);
+                break :b true;
+            },
 
             // ZIR instructions that might be a type other than `noreturn` or `void`.
             .add,
@@ -2412,7 +2421,6 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: Ast.Node.Index) Inner
             .atomic_load,
             .atomic_rmw,
             .mul_add,
-            .builtin_call,
             .field_parent_ptr,
             .maximum,
             .minimum,
@@ -7502,6 +7510,11 @@ fn builtinCall(
                 .options = options,
                 .callee = callee,
                 .args = args,
+                .flags = .{
+                    .is_nosuspend = gz.nosuspend_node != 0,
+                    .is_comptime = gz.force_comptime,
+                    .ensure_result_used = false,
+                },
             });
             return rvalue(gz, rl, result, node);
         },
src/print_zir.zig
@@ -365,7 +365,7 @@ const Writer = struct {
             .@"export" => try self.writePlNodeExport(stream, inst),
             .export_value => try self.writePlNodeExportValue(stream, inst),
 
-            .call => try self.writePlNodeCall(stream, inst),
+            .call => try self.writeCall(stream, inst),
 
             .block,
             .block_inline,
@@ -793,6 +793,11 @@ const Writer = struct {
     fn writeBuiltinCall(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.BuiltinCall, inst_data.payload_index).data;
+
+        try self.writeFlag(stream, "nodiscard ", extra.flags.ensure_result_used);
+        try self.writeFlag(stream, "nosuspend ", extra.flags.is_nosuspend);
+        try self.writeFlag(stream, "comptime ", extra.flags.is_comptime);
+
         try self.writeInstRef(stream, extra.options);
         try stream.writeAll(", ");
         try self.writeInstRef(stream, extra.callee);
@@ -1144,7 +1149,7 @@ const Writer = struct {
         try self.writeSrc(stream, src);
     }
 
-    fn writePlNodeCall(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+    fn writeCall(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.Call, inst_data.payload_index);
         const args = self.code.refSlice(extra.end, extra.data.flags.args_len);
src/Sema.zig
@@ -16097,12 +16097,12 @@ fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     const args_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node };
     const call_src = inst_data.src();
 
-    const extra = sema.code.extraData(Zir.Inst.BuiltinCall, inst_data.payload_index);
-    var func = sema.resolveInst(extra.data.callee);
-    const options = sema.resolveInst(extra.data.options);
-    const args = sema.resolveInst(extra.data.args);
+    const extra = sema.code.extraData(Zir.Inst.BuiltinCall, inst_data.payload_index).data;
+    var func = sema.resolveInst(extra.callee);
+    const options = sema.resolveInst(extra.options);
+    const args = sema.resolveInst(extra.args);
 
-    const modifier: std.builtin.CallOptions.Modifier = modifier: {
+    const wanted_modifier: std.builtin.CallOptions.Modifier = modifier: {
         const call_options_ty = try sema.getBuiltinType(block, options_src, "CallOptions");
         const coerced_options = try sema.coerce(block, call_options_ty, options, options_src);
 
@@ -16118,6 +16118,41 @@ fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         break :modifier modifier_val.toEnum(std.builtin.CallOptions.Modifier);
     };
 
+    const modifier: std.builtin.CallOptions.Modifier = switch (wanted_modifier) {
+        // These can be upgraded to comptime or nosuspend calls.
+        .auto, .never_tail, .no_async => m: {
+            if (extra.flags.is_comptime) {
+                break :m .compile_time;
+            }
+            if (extra.flags.is_nosuspend) {
+                break :m .no_async;
+            }
+            break :m wanted_modifier;
+        },
+        // These can be upgraded to comptime. nosuspend bit can be safely ignored.
+        .always_tail, .always_inline, .compile_time => m: {
+            if (extra.flags.is_comptime) {
+                break :m .compile_time;
+            }
+            break :m wanted_modifier;
+        },
+        .async_kw => m: {
+            if (extra.flags.is_nosuspend) {
+                return sema.fail(block, options_src, "modifier 'async_kw' cannot be used inside nosuspend block", .{});
+            }
+            if (extra.flags.is_comptime) {
+                return sema.fail(block, options_src, "modifier 'async_kw' cannot be used in combination with comptime function call", .{});
+            }
+            break :m wanted_modifier;
+        },
+        .never_inline => m: {
+            if (extra.flags.is_comptime) {
+                return sema.fail(block, options_src, "modifier 'never_inline' cannot be used in combination with comptime function call", .{});
+            }
+            break :m wanted_modifier;
+        },
+    };
+
     const args_ty = sema.typeOf(args);
     if (!args_ty.isTuple() and args_ty.tag() != .empty_struct_literal) {
         return sema.fail(block, args_src, "expected a tuple, found {}", .{args_ty.fmt(sema.mod)});
@@ -16141,8 +16176,8 @@ fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             resolved.* = try sema.tupleFieldValByIndex(block, args_src, args, @intCast(u32, i), args_ty);
         }
     }
-
-    return sema.analyzeCall(block, func, func_src, call_src, modifier, false, resolved_args);
+    const ensure_result_used = extra.flags.ensure_result_used;
+    return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args);
 }
 
 fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
src/Zir.zig
@@ -72,6 +72,7 @@ pub fn extraData(code: Zir, comptime T: type, index: usize) struct { data: T, en
             Inst.Ref => @intToEnum(Inst.Ref, code.extra[i]),
             i32 => @bitCast(i32, code.extra[i]),
             Inst.Call.Flags => @bitCast(Inst.Call.Flags, code.extra[i]),
+            Inst.BuiltinCall.Flags => @bitCast(Inst.BuiltinCall.Flags, code.extra[i]),
             Inst.SwitchBlock.Bits => @bitCast(Inst.SwitchBlock.Bits, code.extra[i]),
             Inst.ExtendedFunc.Bits => @bitCast(Inst.ExtendedFunc.Bits, code.extra[i]),
             else => @compileError("bad field type"),
@@ -280,8 +281,13 @@ pub const Inst = struct {
         /// Uses the `break` union field.
         break_inline,
         /// Function call.
-        /// Uses `pl_node`. AST node is the function call. Payload is `Call`.
+        /// Uses the `pl_node` union field with payload `Call`.
+        /// AST node is the function call.
         call,
+        /// Implements the `@call` builtin.
+        /// Uses the `pl_node` union field with payload `BuiltinCall`.
+        /// AST node is the builtin call.
+        builtin_call,
         /// `<`
         /// Uses the `pl_node` union field. Payload is `Bin`.
         cmp_lt,
@@ -916,9 +922,6 @@ pub const Inst = struct {
         /// The addend communicates the type of the builtin.
         /// The mulends need to be coerced to the same type.
         mul_add,
-        /// Implements the `@call` builtin.
-        /// Uses the `pl_node` union field with payload `BuiltinCall`.
-        builtin_call,
         /// Implements the `@fieldParentPtr` builtin.
         /// Uses the `pl_node` union field with payload `FieldParentPtr`.
         field_parent_ptr,
@@ -2733,9 +2736,24 @@ pub const Inst = struct {
     };
 
     pub const BuiltinCall = struct {
+        // Note: Flags *must* come first so that unusedResultExpr
+        // can find it when it goes to modify them.
+        flags: Flags,
         options: Ref,
         callee: Ref,
         args: Ref,
+
+        pub const Flags = packed struct {
+            is_nosuspend: bool,
+            is_comptime: bool,
+            ensure_result_used: bool,
+            _: u29 = undefined,
+
+            comptime {
+                if (@sizeOf(Flags) != 4 or @bitSizeOf(Flags) != 32)
+                    @compileError("Layout of BuiltinCall.Flags needs to be updated!");
+            }
+        };
     };
 
     /// This data is stored inside extra, with two sets of trailing `Ref`:
test/behavior/call.zig
@@ -19,7 +19,11 @@ test "super basic invocations" {
 }
 
 test "basic invocations" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
 
     const foo = struct {
         fn foo() i32 {
@@ -41,7 +45,10 @@ test "basic invocations" {
     }
     {
         // call of non comptime-known function
-        var alias_foo = foo;
+        var alias_foo = switch (builtin.zig_backend) {
+            .stage1 => foo,
+            else => &foo,
+        };
         try expect(@call(.{ .modifier = .no_async }, alias_foo, .{}) == 1234);
         try expect(@call(.{ .modifier = .never_tail }, alias_foo, .{}) == 1234);
         try expect(@call(.{ .modifier = .never_inline }, alias_foo, .{}) == 1234);
@@ -79,26 +86,6 @@ test "tuple parameters" {
     }
 }
 
-test "comptime call with bound function as parameter" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
-    const S = struct {
-        fn ReturnType(func: anytype) type {
-            return switch (@typeInfo(@TypeOf(func))) {
-                .BoundFn => |info| info,
-                else => unreachable,
-            }.return_type orelse void;
-        }
-
-        fn call_me_maybe() ?i32 {
-            return 123;
-        }
-    };
-
-    var inst: S = undefined;
-    try expectEqual(?i32, S.ReturnType(inst.call_me_maybe));
-}
-
 test "result location of function call argument through runtime condition and struct init" {
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO