Commit 568f333681

Andrew Kelley <andrew@ziglang.org>
2021-03-23 02:57:46
astgen: improve the ensure_unused_result elision
1 parent 2f391df
src/astgen.zig
@@ -1024,9 +1024,178 @@ fn blockExprStmts(
             .assign_mul_wrap => try assignOp(mod, scope, statement, .mulwrap),
 
             else => {
-                const possibly_unused_result = try expr(mod, scope, .none, statement);
-                if (!gz.zir_code.isVoidOrNoReturn(possibly_unused_result)) {
-                    _ = try gz.addUnNode(.ensure_result_used, possibly_unused_result, statement);
+                // We need to emit an error if the result is not `noreturn` or `void`, but
+                // we want to avoid adding the ZIR instruction if possible for performance.
+                const maybe_unused_result = try expr(mod, scope, .none, statement);
+                const elide_check = if (maybe_unused_result >= gz.zir_code.ref_start_index) b: {
+                    const inst = maybe_unused_result - gz.zir_code.ref_start_index;
+                    // Note that this array becomes invalid after appending more items to it
+                    // in the above while loop.
+                    const zir_tags = gz.zir_code.instructions.items(.tag);
+                    switch (zir_tags[inst]) {
+                        .@"const" => {
+                            const tv = gz.zir_code.instructions.items(.data)[inst].@"const";
+                            break :b switch (tv.ty.zigTypeTag()) {
+                                .NoReturn, .Void => true,
+                                else => false,
+                            };
+                        },
+                        // For some instructions, swap in a slightly different ZIR tag
+                        // so we can avoid a separate ensure_result_used instruction.
+                        .call_none_chkused => unreachable,
+                        .call_none => {
+                            zir_tags[inst] = .call_none_chkused;
+                            break :b true;
+                        },
+                        .call_chkused => unreachable,
+                        .call => {
+                            zir_tags[inst] = .call_chkused;
+                            break :b true;
+                        },
+
+                        // ZIR instructions that might be a type other than `noreturn` or `void`.
+                        .add,
+                        .addwrap,
+                        .alloc,
+                        .alloc_mut,
+                        .alloc_inferred,
+                        .alloc_inferred_mut,
+                        .array_cat,
+                        .array_mul,
+                        .array_type,
+                        .array_type_sentinel,
+                        .indexable_ptr_len,
+                        .as,
+                        .as_node,
+                        .@"asm",
+                        .asm_volatile,
+                        .bit_and,
+                        .bitcast,
+                        .bitcast_ref,
+                        .bitcast_result_ptr,
+                        .bit_or,
+                        .block,
+                        .block_comptime,
+                        .bool_br_and,
+                        .bool_br_or,
+                        .bool_not,
+                        .bool_and,
+                        .bool_or,
+                        .call_compile_time,
+                        .cmp_lt,
+                        .cmp_lte,
+                        .cmp_eq,
+                        .cmp_gte,
+                        .cmp_gt,
+                        .cmp_neq,
+                        .coerce_result_ptr,
+                        .decl_ref,
+                        .decl_val,
+                        .deref_node,
+                        .div,
+                        .elem_ptr,
+                        .elem_val,
+                        .elem_ptr_node,
+                        .elem_val_node,
+                        .floatcast,
+                        .field_ptr,
+                        .field_val,
+                        .field_ptr_named,
+                        .field_val_named,
+                        .fn_type,
+                        .fn_type_var_args,
+                        .fn_type_cc,
+                        .fn_type_cc_var_args,
+                        .int,
+                        .intcast,
+                        .int_type,
+                        .is_non_null,
+                        .is_null,
+                        .is_non_null_ptr,
+                        .is_null_ptr,
+                        .is_err,
+                        .is_err_ptr,
+                        .mod_rem,
+                        .mul,
+                        .mulwrap,
+                        .param_type,
+                        .ptrtoint,
+                        .ref,
+                        .ret_ptr,
+                        .ret_type,
+                        .shl,
+                        .shr,
+                        .str,
+                        .sub,
+                        .subwrap,
+                        .negate,
+                        .negate_wrap,
+                        .typeof,
+                        .xor,
+                        .optional_type,
+                        .optional_type_from_ptr_elem,
+                        .optional_payload_safe,
+                        .optional_payload_unsafe,
+                        .optional_payload_safe_ptr,
+                        .optional_payload_unsafe_ptr,
+                        .err_union_payload_safe,
+                        .err_union_payload_unsafe,
+                        .err_union_payload_safe_ptr,
+                        .err_union_payload_unsafe_ptr,
+                        .err_union_code,
+                        .err_union_code_ptr,
+                        .ptr_type,
+                        .ptr_type_simple,
+                        .enum_literal,
+                        .enum_literal_small,
+                        .merge_error_sets,
+                        .error_union_type,
+                        .bit_not,
+                        .error_set,
+                        .error_value,
+                        .slice_start,
+                        .slice_end,
+                        .slice_sentinel,
+                        .import,
+                        .typeof_peer,
+                        => break :b false,
+
+                        // ZIR instructions that are always either `noreturn` or `void`.
+                        .breakpoint,
+                        .dbg_stmt_node,
+                        .ensure_result_used,
+                        .ensure_result_non_error,
+                        .set_eval_branch_quota,
+                        .compile_log,
+                        .ensure_err_payload_void,
+                        .@"break",
+                        .break_void_tok,
+                        .break_flat,
+                        .condbr,
+                        .compile_error,
+                        .ret_node,
+                        .ret_tok,
+                        .ret_coerce,
+                        .@"unreachable",
+                        .loop,
+                        .elided,
+                        .store,
+                        .store_to_block_ptr,
+                        .store_to_inferred_ptr,
+                        .resolve_inferred_alloc,
+                        => break :b true,
+                    }
+                } else switch (maybe_unused_result) {
+                    @enumToInt(zir.Const.unused) => unreachable,
+
+                    @enumToInt(zir.Const.void_value),
+                    @enumToInt(zir.Const.unreachable_value),
+                    => true,
+
+                    else => false,
+                };
+                if (!elide_check) {
+                    _ = try gz.addUnNode(.ensure_result_used, maybe_unused_result, statement);
                 }
             },
         }
src/Module.zig
@@ -1398,163 +1398,6 @@ pub const WipZirCode = struct {
         return result;
     }
 
-    /// Returns `true` if and only if the instruction *always* has a void type, or
-    /// *always* has a NoReturn type. Function calls return false because
-    /// the answer depends on their type.
-    /// This is used to elide unnecessary `ensure_result_used` instructions.
-    pub fn isVoidOrNoReturn(wzc: WipZirCode, inst_ref: zir.Inst.Ref) bool {
-        if (inst_ref >= wzc.ref_start_index) {
-            const inst = inst_ref - wzc.ref_start_index;
-            const tags = wzc.instructions.items(.tag);
-            switch (tags[inst]) {
-                .@"const" => {
-                    const tv = wzc.instructions.items(.data)[inst].@"const";
-                    return switch (tv.ty.zigTypeTag()) {
-                        .NoReturn, .Void => true,
-                        else => false,
-                    };
-                },
-
-                .add,
-                .addwrap,
-                .alloc,
-                .alloc_mut,
-                .alloc_inferred,
-                .alloc_inferred_mut,
-                .array_cat,
-                .array_mul,
-                .array_type,
-                .array_type_sentinel,
-                .indexable_ptr_len,
-                .as,
-                .as_node,
-                .@"asm",
-                .asm_volatile,
-                .bit_and,
-                .bitcast,
-                .bitcast_ref,
-                .bitcast_result_ptr,
-                .bit_or,
-                .block,
-                .block_comptime,
-                .bool_br_and,
-                .bool_br_or,
-                .bool_not,
-                .bool_and,
-                .bool_or,
-                .call,
-                .call_compile_time,
-                .call_none,
-                .cmp_lt,
-                .cmp_lte,
-                .cmp_eq,
-                .cmp_gte,
-                .cmp_gt,
-                .cmp_neq,
-                .coerce_result_ptr,
-                .decl_ref,
-                .decl_val,
-                .deref_node,
-                .div,
-                .elem_ptr,
-                .elem_val,
-                .elem_ptr_node,
-                .elem_val_node,
-                .floatcast,
-                .field_ptr,
-                .field_val,
-                .field_ptr_named,
-                .field_val_named,
-                .fn_type,
-                .fn_type_var_args,
-                .fn_type_cc,
-                .fn_type_cc_var_args,
-                .int,
-                .intcast,
-                .int_type,
-                .is_non_null,
-                .is_null,
-                .is_non_null_ptr,
-                .is_null_ptr,
-                .is_err,
-                .is_err_ptr,
-                .mod_rem,
-                .mul,
-                .mulwrap,
-                .param_type,
-                .ptrtoint,
-                .ref,
-                .ret_ptr,
-                .ret_type,
-                .shl,
-                .shr,
-                .str,
-                .sub,
-                .subwrap,
-                .negate,
-                .negate_wrap,
-                .typeof,
-                .xor,
-                .optional_type,
-                .optional_type_from_ptr_elem,
-                .optional_payload_safe,
-                .optional_payload_unsafe,
-                .optional_payload_safe_ptr,
-                .optional_payload_unsafe_ptr,
-                .err_union_payload_safe,
-                .err_union_payload_unsafe,
-                .err_union_payload_safe_ptr,
-                .err_union_payload_unsafe_ptr,
-                .err_union_code,
-                .err_union_code_ptr,
-                .ptr_type,
-                .ptr_type_simple,
-                .enum_literal,
-                .enum_literal_small,
-                .merge_error_sets,
-                .error_union_type,
-                .bit_not,
-                .error_set,
-                .error_value,
-                .slice_start,
-                .slice_end,
-                .slice_sentinel,
-                .import,
-                .typeof_peer,
-                => return false,
-
-                .breakpoint,
-                .dbg_stmt_node,
-                .ensure_result_used,
-                .ensure_result_non_error,
-                .set_eval_branch_quota,
-                .compile_log,
-                .ensure_err_payload_void,
-                .@"break",
-                .break_void_tok,
-                .break_flat,
-                .condbr,
-                .compile_error,
-                .ret_node,
-                .ret_tok,
-                .ret_coerce,
-                .@"unreachable",
-                .loop,
-                .elided,
-                .store,
-                .store_to_block_ptr,
-                .store_to_inferred_ptr,
-                .resolve_inferred_alloc,
-                => return true,
-            }
-        }
-        return switch (inst_ref) {
-            @enumToInt(zir.Const.unused) => unreachable,
-            @enumToInt(zir.Const.void_value), @enumToInt(zir.Const.unreachable_value) => true,
-            else => false,
-        };
-    }
-
     pub fn deinit(wzc: *WipZirCode) void {
         wzc.instructions.deinit(wzc.gpa);
         wzc.extra.deinit(wzc.gpa);
src/Sema.zig
@@ -126,9 +126,11 @@ pub fn analyzeBody(sema: *Sema, block: *Scope.Block, body: []const zir.Inst.Inde
             .bool_or => try sema.zirBoolOp(block, inst, true),
             .bool_br_and => try sema.zirBoolBr(block, inst, false),
             .bool_br_or => try sema.zirBoolBr(block, inst, true),
-            .call => try sema.zirCall(block, inst, .auto),
-            .call_compile_time => try sema.zirCall(block, inst, .compile_time),
-            .call_none => try sema.zirCallNone(block, inst),
+            .call => try sema.zirCall(block, inst, .auto, false),
+            .call_chkused => try sema.zirCall(block, inst, .auto, true),
+            .call_compile_time => try sema.zirCall(block, inst, .compile_time, false),
+            .call_none => try sema.zirCallNone(block, inst, false),
+            .call_none_chkused => try sema.zirCallNone(block, inst, true),
             .cmp_eq => try sema.zirCmp(block, inst, .eq),
             .cmp_gt => try sema.zirCmp(block, inst, .gt),
             .cmp_gte => try sema.zirCmp(block, inst, .gte),
@@ -457,6 +459,16 @@ fn zirEnsureResultUsed(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) I
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const operand = try sema.resolveInst(inst_data.operand);
     const src = inst_data.src();
+
+    return sema.ensureResultUsed(block, operand, src);
+}
+
+fn ensureResultUsed(
+    sema: *Sema,
+    block: *Scope.Block,
+    operand: *Inst,
+    src: LazySrcLoc,
+) InnerError!void {
     switch (operand.ty.zigTypeTag()) {
         .Void, .NoReturn => return,
         else => return sema.mod.fail(&block.base, src, "expression value is ignored", .{}),
@@ -1027,14 +1039,19 @@ fn zirDeclVal(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError
     return sema.analyzeDeclVal(block, .unneeded, decl);
 }
 
-fn zirCallNone(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {
+fn zirCallNone(
+    sema: *Sema,
+    block: *Scope.Block,
+    inst: zir.Inst.Index,
+    ensure_result_used: bool,
+) InnerError!*Inst {
     const tracy = trace(@src());
     defer tracy.end();
 
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const func_src: LazySrcLoc = .{ .node_offset_call_func = inst_data.src_node };
 
-    return sema.analyzeCall(block, inst_data.operand, func_src, inst_data.src(), .auto, &.{});
+    return sema.analyzeCall(block, inst_data.operand, func_src, inst_data.src(), .auto, ensure_result_used, &.{});
 }
 
 fn zirCall(
@@ -1042,6 +1059,7 @@ fn zirCall(
     block: *Scope.Block,
     inst: zir.Inst.Index,
     modifier: std.builtin.CallOptions.Modifier,
+    ensure_result_used: bool,
 ) InnerError!*Inst {
     const tracy = trace(@src());
     defer tracy.end();
@@ -1052,7 +1070,7 @@ fn zirCall(
     const extra = sema.code.extraData(zir.Inst.Call, inst_data.payload_index);
     const args = sema.code.extra[extra.end..][0..extra.data.args_len];
 
-    return sema.analyzeCall(block, extra.data.callee, func_src, call_src, modifier, args);
+    return sema.analyzeCall(block, extra.data.callee, func_src, call_src, modifier, ensure_result_used, args);
 }
 
 fn analyzeCall(
@@ -1062,6 +1080,7 @@ fn analyzeCall(
     func_src: LazySrcLoc,
     call_src: LazySrcLoc,
     modifier: std.builtin.CallOptions.Modifier,
+    ensure_result_used: bool,
     zir_args: []const zir.Inst.Ref,
 ) InnerError!*ir.Inst {
     const func = try sema.resolveInst(zir_func);
@@ -1121,7 +1140,7 @@ fn analyzeCall(
     const is_comptime_call = block.is_comptime or modifier == .compile_time;
     const is_inline_call = is_comptime_call or modifier == .always_inline or
         func.ty.fnCallingConvention() == .Inline;
-    if (is_inline_call) {
+    const result: *Inst = if (is_inline_call) res: {
         const func_val = try sema.resolveConstValue(block, func_src, func);
         const module_fn = switch (func_val.tag()) {
             .function => func_val.castTag(.function).?.data,
@@ -1195,10 +1214,13 @@ fn analyzeCall(
         // the block_inst above.
         _ = try sema.root(&child_block);
 
-        return sema.analyzeBlockBody(block, &child_block, merges);
-    }
+        break :res try sema.analyzeBlockBody(block, &child_block, merges);
+    } else try block.addCall(call_src, ret_type, func, casted_args);
 
-    return block.addCall(call_src, ret_type, func, casted_args);
+    if (ensure_result_used) {
+        try sema.ensureResultUsed(block, result, call_src);
+    }
+    return result;
 }
 
 fn zirIntType(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst {
src/zir.zig
@@ -489,11 +489,15 @@ pub const Inst = struct {
         /// Function call with modifier `.auto`.
         /// Uses `pl_node`. AST node is the function call. Payload is `Call`.
         call,
+        /// Same as `call` but it also does `ensure_result_used` on the return value.
+        call_chkused,
         /// Same as `call` but with modifier `.compile_time`.
         call_compile_time,
         /// Function call with modifier `.auto`, empty parameter list.
         /// Uses the `un_node` field. Operand is callee. AST node is the function call.
         call_none,
+        /// Same as `call_none` but it also does `ensure_result_used` on the return value.
+        call_none_chkused,
         /// `<`
         /// Uses the `pl_node` union field. Payload is `Bin`.
         cmp_lt,
@@ -898,8 +902,10 @@ pub const Inst = struct {
                 .bool_or,
                 .breakpoint,
                 .call,
+                .call_chkused,
                 .call_compile_time,
                 .call_none,
+                .call_none_chkused,
                 .cmp_lt,
                 .cmp_lte,
                 .cmp_eq,
@@ -1337,6 +1343,7 @@ const Writer = struct {
             .negate,
             .negate_wrap,
             .call_none,
+            .call_none_chkused,
             .compile_error,
             .deref_node,
             .ensure_result_used,
@@ -1393,6 +1400,7 @@ const Writer = struct {
             .block,
             .block_comptime,
             .call,
+            .call_chkused,
             .call_compile_time,
             .compile_log,
             .condbr,