Commit b529d8e48f

Cody Tapscott <topolarity@tapscott.me>
2022-09-23 20:50:55
stage2: Propagate error return trace into fn call
This change extends the "lifetime" of the error return trace associated with an error to include the duration of a function call it is passed to. This means that if a function returns an error, its return trace will include the error return trace for any error inputs. This is needed to support `testing.expectError` and similar functions. If a function returns a non-error, we have to clean up any error return traces created by error-able call arguments.
1 parent 77720e3
Changed files (3)
src/AstGen.zig
@@ -335,6 +335,8 @@ pub const ResultInfo = struct {
         error_handling_expr,
         /// The expression is the right-hand side of a shift operation.
         shift_op,
+        /// The expression is an argument in a function call.
+        fn_arg,
         /// No specific operator in particular.
         none,
     };
@@ -5217,9 +5219,9 @@ fn popErrorReturnTrace(
 
     const result_is_err = nodeMayEvalToError(tree, node);
 
-    // If we are breaking to a try/catch/error-union-if/return, the error trace propagates.
+    // If we are breaking to a try/catch/error-union-if/return or a function call, the error trace propagates.
     const propagate_error_trace = switch (ri.ctx) {
-        .error_handling_expr, .@"return" => true,
+        .error_handling_expr, .@"return", .fn_arg => true,
         else => false,
     };
 
@@ -8548,7 +8550,7 @@ fn callExpr(
         defer arg_block.unstack();
 
         // `call_inst` is reused to provide the param type.
-        const arg_ref = try expr(&arg_block, &arg_block.base, .{ .rl = .{ .coerced_ty = call_inst } }, param_node);
+        const arg_ref = try expr(&arg_block, &arg_block.base, .{ .rl = .{ .coerced_ty = call_inst }, .ctx = .fn_arg }, param_node);
         _ = try arg_block.addBreak(.break_inline, call_index, arg_ref);
 
         const body = arg_block.instructionsSlice();
@@ -8562,7 +8564,7 @@ fn callExpr(
     // If our result location is a try/catch/error-union-if/return, the error trace propagates.
     // Otherwise, it should always be popped (handled in Sema).
     const propagate_error_trace = switch (ri.ctx) {
-        .error_handling_expr, .@"return" => true, // Propagate to try/catch/error-union-if and return
+        .error_handling_expr, .@"return", .fn_arg => true, // Propagate to try/catch/error-union-if, return, and other function calls
         else => false,
     };
 
src/Sema.zig
@@ -499,6 +499,25 @@ pub const Block = struct {
         return result_index;
     }
 
+    /// Insert an instruction into the block at `index`. Moves all following
+    /// instructions forward in the block to make room. Operation is O(N).
+    pub fn insertInst(block: *Block, index: Air.Inst.Index, inst: Air.Inst) error{OutOfMemory}!Air.Inst.Ref {
+        return Air.indexToRef(try block.insertInstAsIndex(index, inst));
+    }
+
+    pub fn insertInstAsIndex(block: *Block, index: Air.Inst.Index, inst: Air.Inst) error{OutOfMemory}!Air.Inst.Index {
+        const sema = block.sema;
+        const gpa = sema.gpa;
+
+        try sema.air_instructions.ensureUnusedCapacity(gpa, 1);
+
+        const result_index = @intCast(Air.Inst.Index, sema.air_instructions.len);
+        sema.air_instructions.appendAssumeCapacity(inst);
+
+        try block.instructions.insert(gpa, index, result_index);
+        return result_index;
+    }
+
     fn addUnreachable(block: *Block, src: LazySrcLoc, safety_check: bool) !void {
         if (safety_check and block.wantSafety()) {
             _ = try block.sema.safetyPanic(block, src, .unreach);
@@ -5648,6 +5667,85 @@ fn funcDeclSrc(sema: *Sema, block: *Block, src: LazySrcLoc, func_inst: Air.Inst.
     return owner_decl.srcLoc();
 }
 
+/// Add instructions to block to "pop" the error return trace.
+/// If `operand` is provided, only pops if operand is non-error.
+fn popErrorReturnTrace(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    operand: ?Air.Inst.Ref,
+    saved_error_trace_index: Air.Inst.Ref,
+) CompileError!void {
+    var is_non_error: ?bool = null;
+    var is_non_error_inst: Air.Inst.Ref = undefined;
+    if (operand) |op| {
+        is_non_error_inst = try sema.analyzeIsNonErr(block, src, op);
+        if (try sema.resolveDefinedValue(block, src, is_non_error_inst)) |cond_val|
+            is_non_error = cond_val.toBool();
+    } else is_non_error = true; // no operand means pop unconditionally
+
+    if (is_non_error == true) {
+        // AstGen determined this result does not go to an error-handling expr (try/catch/return etc.), or
+        // the result is comptime-known to be a non-error. Either way, pop unconditionally.
+
+        const unresolved_stack_trace_ty = try sema.getBuiltinType(block, src, "StackTrace");
+        const stack_trace_ty = try sema.resolveTypeFields(block, src, unresolved_stack_trace_ty);
+        const ptr_stack_trace_ty = try Type.Tag.single_mut_pointer.create(sema.arena, stack_trace_ty);
+        const err_return_trace = try block.addTy(.err_return_trace, ptr_stack_trace_ty);
+        const field_ptr = try sema.structFieldPtr(block, src, err_return_trace, "index", src, stack_trace_ty, true);
+        try sema.storePtr2(block, src, field_ptr, src, saved_error_trace_index, src, .store);
+    } else if (is_non_error == null) {
+        // The result might be an error. If it is, we leave the error trace alone. If it isn't, we need
+        // to pop any error trace that may have been propagated from our arguments.
+
+        try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.Block).Struct.fields.len);
+        const cond_block_inst = try block.addInstAsIndex(.{
+            .tag = .block,
+            .data = .{
+                .ty_pl = .{
+                    .ty = Air.Inst.Ref.void_type,
+                    .payload = undefined, // updated below
+                },
+            },
+        });
+
+        var then_block = block.makeSubBlock();
+        defer then_block.instructions.deinit(sema.gpa);
+
+        // If non-error, then pop the error return trace by restoring the index.
+        const unresolved_stack_trace_ty = try sema.getBuiltinType(block, src, "StackTrace");
+        const stack_trace_ty = try sema.resolveTypeFields(block, src, unresolved_stack_trace_ty);
+        const ptr_stack_trace_ty = try Type.Tag.single_mut_pointer.create(sema.arena, stack_trace_ty);
+        const err_return_trace = try then_block.addTy(.err_return_trace, ptr_stack_trace_ty);
+        const field_ptr = try sema.structFieldPtr(&then_block, src, err_return_trace, "index", src, stack_trace_ty, true);
+        try sema.storePtr2(&then_block, src, field_ptr, src, saved_error_trace_index, src, .store);
+        _ = try then_block.addBr(cond_block_inst, Air.Inst.Ref.void_value);
+
+        // Otherwise, do nothing
+        var else_block = block.makeSubBlock();
+        defer else_block.instructions.deinit(sema.gpa);
+        _ = try else_block.addBr(cond_block_inst, Air.Inst.Ref.void_value);
+
+        try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.CondBr).Struct.fields.len +
+            then_block.instructions.items.len + else_block.instructions.items.len +
+            @typeInfo(Air.Block).Struct.fields.len + 1); // +1 for the sole .cond_br instruction in the .block
+
+        const cond_br_inst = @intCast(Air.Inst.Index, sema.air_instructions.len);
+        try sema.air_instructions.append(sema.gpa, .{ .tag = .cond_br, .data = .{ .pl_op = .{
+            .operand = is_non_error_inst,
+            .payload = sema.addExtraAssumeCapacity(Air.CondBr{
+                .then_body_len = @intCast(u32, then_block.instructions.items.len),
+                .else_body_len = @intCast(u32, else_block.instructions.items.len),
+            }),
+        } } });
+        sema.air_extra.appendSliceAssumeCapacity(then_block.instructions.items);
+        sema.air_extra.appendSliceAssumeCapacity(else_block.instructions.items);
+
+        sema.air_instructions.items(.data)[cond_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ .body_len = 1 });
+        sema.air_extra.appendAssumeCapacity(cond_br_inst);
+    }
+}
+
 fn zirCall(
     sema: *Sema,
     block: *Block,
@@ -5737,6 +5835,9 @@ fn zirCall(
 
     const args_body = sema.code.extra[extra.end..];
 
+    var input_is_error = false;
+    const block_index = @intCast(Air.Inst.Index, block.instructions.items.len);
+
     const parent_comptime = block.is_comptime;
     // `extra_index` and `arg_index` are separate since the bound function is passed as the first argument.
     var extra_index: usize = 0;
@@ -5754,10 +5855,8 @@ fn zirCall(
         else
             func_ty_info.param_types[arg_index];
 
-        const old_comptime = block.is_comptime;
-        defer block.is_comptime = old_comptime;
         // Generate args to comptime params in comptime block.
-        block.is_comptime = parent_comptime;
+        defer block.is_comptime = parent_comptime;
         if (arg_index < fn_params_len and func_ty_info.comptime_params[arg_index]) {
             block.is_comptime = true;
         }
@@ -5766,13 +5865,58 @@ fn zirCall(
         try sema.inst_map.put(sema.gpa, inst, param_ty_inst);
 
         const resolved = try sema.resolveBody(block, args_body[arg_start..arg_end], inst);
-        if (sema.typeOf(resolved).zigTypeTag() == .NoReturn) {
+        const resolved_ty = sema.typeOf(resolved);
+        if (resolved_ty.zigTypeTag() == .NoReturn) {
             return resolved;
         }
+        if (resolved_ty.isError()) {
+            input_is_error = true;
+        }
         resolved_args[arg_index] = resolved;
     }
+    if (sema.owner_func == null or !sema.owner_func.?.calls_or_awaits_errorable_fn)
+        input_is_error = false; // input was an error type, but no errorable fn's were actually called
+
+    const backend_supports_error_return_tracing = sema.mod.comp.bin_file.options.use_llvm;
+    if (backend_supports_error_return_tracing and sema.mod.comp.bin_file.options.error_return_tracing and
+        !block.is_comptime and (input_is_error or pop_error_return_trace))
+    {
+        const call_inst: Air.Inst.Ref = if (modifier == .always_tail) undefined else b: {
+            break :b try sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src);
+        };
+
+        const return_ty = sema.typeOf(call_inst);
+        if (modifier != .always_tail and return_ty.isNoReturn())
+            return call_inst; // call to "fn(...) noreturn", don't pop
+
+        // If any input is an error-type, we might need to pop any trace it generated. Otherwise, we only
+        // need to clean-up our own trace if we were passed to a non-error-handling expression.
+        if (input_is_error or (pop_error_return_trace and modifier != .always_tail and return_ty.isError())) {
+            const unresolved_stack_trace_ty = try sema.getBuiltinType(block, call_src, "StackTrace");
+            const stack_trace_ty = try sema.resolveTypeFields(block, call_src, unresolved_stack_trace_ty);
+            const field_index = try sema.structFieldIndex(block, stack_trace_ty, "index", call_src);
+
+            // Insert a save instruction before the arg resolution + call instructions we just generated
+            const save_inst = try block.insertInst(block_index, .{
+                .tag = .save_err_return_trace_index,
+                .data = .{ .ty_pl = .{
+                    .ty = try sema.addType(stack_trace_ty),
+                    .payload = @intCast(u32, field_index),
+                } },
+            });
+
+            // Pop the error return trace, testing the result for non-error if necessary
+            const operand = if (pop_error_return_trace or modifier == .always_tail) null else call_inst;
+            try sema.popErrorReturnTrace(block, call_src, operand, save_inst);
+        }
 
-    return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, pop_error_return_trace, resolved_args, bound_arg_src);
+        if (modifier == .always_tail) // Perform the call *after* the restore, so that a tail call is possible.
+            return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src);
+
+        return call_inst;
+    } else {
+        return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src);
+    }
 }
 
 const GenericCallAdapter = struct {
@@ -5884,7 +6028,6 @@ fn analyzeCall(
     call_src: LazySrcLoc,
     modifier: std.builtin.CallOptions.Modifier,
     ensure_result_used: bool,
-    pop_error_return_trace: bool,
     uncasted_args: []const Air.Inst.Ref,
     bound_arg_src: ?LazySrcLoc,
 ) CompileError!Air.Inst.Ref {
@@ -6335,55 +6478,19 @@ fn analyzeCall(
             sema.owner_func.?.calls_or_awaits_errorable_fn = true;
         }
 
-        const backend_supports_error_return_tracing = sema.mod.comp.bin_file.options.use_llvm;
-        const emit_error_trace_save_restore = sema.mod.comp.bin_file.options.error_return_tracing and
-            backend_supports_error_return_tracing and
-            pop_error_return_trace and func_ty_info.return_type.isError();
-
-        if (emit_error_trace_save_restore) {
-            // This function call is error-able (and so can generate an error trace), but AstGen determined
-            // that its result does not go to an error-handling operator (try/catch/return etc.). We need to
-            // save and restore the error trace index here, effectively "popping" the new entries immediately.
-
-            const unresolved_stack_trace_ty = try sema.getBuiltinType(block, call_src, "StackTrace");
-            const stack_trace_ty = try sema.resolveTypeFields(block, call_src, unresolved_stack_trace_ty);
-            const ptr_stack_trace_ty = try Type.Tag.single_mut_pointer.create(sema.arena, stack_trace_ty);
-            const err_return_trace = try block.addTy(.err_return_trace, ptr_stack_trace_ty);
-            const field_ptr = try sema.structFieldPtr(block, call_src, err_return_trace, "index", call_src, stack_trace_ty, true);
-
-            const saved_index = try sema.analyzeLoad(block, call_src, field_ptr, call_src);
-
-            try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Call).Struct.fields.len +
-                args.len);
-            const func_inst = try block.addInst(.{
-                .tag = call_tag,
-                .data = .{ .pl_op = .{
-                    .operand = func,
-                    .payload = sema.addExtraAssumeCapacity(Air.Call{
-                        .args_len = @intCast(u32, args.len),
-                    }),
-                } },
-            });
-            sema.appendRefsAssumeCapacity(args);
-
-            try sema.storePtr2(block, call_src, field_ptr, call_src, saved_index, call_src, .store);
-
-            break :res func_inst;
-        } else {
-            try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Call).Struct.fields.len +
-                args.len);
-            const func_inst = try block.addInst(.{
-                .tag = call_tag,
-                .data = .{ .pl_op = .{
-                    .operand = func,
-                    .payload = sema.addExtraAssumeCapacity(Air.Call{
-                        .args_len = @intCast(u32, args.len),
-                    }),
-                } },
-            });
-            sema.appendRefsAssumeCapacity(args);
-            break :res func_inst;
-        }
+        try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Call).Struct.fields.len +
+            args.len);
+        const func_inst = try block.addInst(.{
+            .tag = call_tag,
+            .data = .{ .pl_op = .{
+                .operand = func,
+                .payload = sema.addExtraAssumeCapacity(Air.Call{
+                    .args_len = @intCast(u32, args.len),
+                }),
+            } },
+        });
+        sema.appendRefsAssumeCapacity(args);
+        break :res func_inst;
     };
 
     if (ensure_result_used) {
@@ -10965,7 +11072,7 @@ fn maybeErrorUnwrap(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, op
                 const panic_fn = try sema.getBuiltin(block, src, "panicUnwrapError");
                 const err_return_trace = try sema.getErrorReturnTrace(block, src);
                 const args: [2]Air.Inst.Ref = .{ err_return_trace, operand };
-                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, false, &args, null);
+                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null);
                 return true;
             },
             .panic => {
@@ -10976,7 +11083,7 @@ fn maybeErrorUnwrap(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, op
                 const panic_fn = try sema.getBuiltin(block, src, "panic");
                 const err_return_trace = try sema.getErrorReturnTrace(block, src);
                 const args: [3]Air.Inst.Ref = .{ msg_inst, err_return_trace, .null_value };
-                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, false, &args, null);
+                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null);
                 return true;
             },
             else => unreachable,
@@ -16179,7 +16286,7 @@ fn retWithErrTracing(
     const args: [1]Air.Inst.Ref = .{err_return_trace};
 
     if (!need_check) {
-        _ = try sema.analyzeCall(block, return_err_fn, src, src, .never_inline, false, false, &args, null);
+        _ = try sema.analyzeCall(block, return_err_fn, src, src, .never_inline, false, &args, null);
         _ = try block.addUnOp(ret_tag, operand);
         return always_noreturn;
     }
@@ -16190,7 +16297,7 @@ fn retWithErrTracing(
 
     var else_block = block.makeSubBlock();
     defer else_block.instructions.deinit(gpa);
-    _ = try sema.analyzeCall(&else_block, return_err_fn, src, src, .never_inline, false, false, &args, null);
+    _ = try sema.analyzeCall(&else_block, return_err_fn, src, src, .never_inline, false, &args, null);
     _ = try else_block.addUnOp(ret_tag, operand);
 
     try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.CondBr).Struct.fields.len +
@@ -20414,7 +20521,7 @@ fn zirBuiltinCall(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         }
     }
     const ensure_result_used = extra.flags.ensure_result_used;
-    return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, false, resolved_args, bound_arg_src);
+    return sema.analyzeCall(block, func, func_src, call_src, modifier, ensure_result_used, resolved_args, bound_arg_src);
 }
 
 fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -21848,7 +21955,7 @@ fn panicWithMsg(
         Value.@"null",
     );
     const args: [3]Air.Inst.Ref = .{ msg_inst, null_stack_trace, .null_value };
-    _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, false, &args, null);
+    _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null);
     return always_noreturn;
 }
 
@@ -21889,7 +21996,7 @@ fn panicUnwrapError(
             const err = try fail_block.addTyOp(unwrap_err_tag, Type.anyerror, operand);
             const err_return_trace = try sema.getErrorReturnTrace(&fail_block, src);
             const args: [2]Air.Inst.Ref = .{ err_return_trace, err };
-            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, false, &args, null);
+            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, &args, null);
         }
     }
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
@@ -21930,7 +22037,7 @@ fn panicIndexOutOfBounds(
         } else {
             const panic_fn = try sema.getBuiltin(&fail_block, src, "panicOutOfBounds");
             const args: [2]Air.Inst.Ref = .{ index, len };
-            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, false, &args, null);
+            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, &args, null);
         }
     }
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
@@ -21972,7 +22079,7 @@ fn panicSentinelMismatch(
     else {
         const panic_fn = try sema.getBuiltin(parent_block, src, "checkNonScalarSentinel");
         const args: [2]Air.Inst.Ref = .{ expected_sentinel, actual_sentinel };
-        _ = try sema.analyzeCall(parent_block, panic_fn, src, src, .auto, false, false, &args, null);
+        _ = try sema.analyzeCall(parent_block, panic_fn, src, src, .auto, false, &args, null);
         return;
     };
     const gpa = sema.gpa;
@@ -22001,7 +22108,7 @@ fn panicSentinelMismatch(
         } else {
             const panic_fn = try sema.getBuiltin(&fail_block, src, "panicSentinelMismatch");
             const args: [2]Air.Inst.Ref = .{ expected_sentinel, actual_sentinel };
-            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, false, &args, null);
+            _ = try sema.analyzeCall(&fail_block, panic_fn, src, src, .auto, false, &args, null);
         }
     }
     try sema.addSafetyCheckExtra(parent_block, ok, &fail_block);
test/stack_traces.zig
@@ -260,6 +260,75 @@ pub fn addCases(cases: *tests.StackTracesContext) void {
         },
     });
 
+    cases.addCase(.{
+        .name = "error passed to function has its trace preserved for duration of the call",
+        .source = 
+        \\pub fn expectError(expected_error: anyerror, actual_error: anyerror!void) !void {
+        \\    actual_error catch |err| {
+        \\        if (err == expected_error) return {};
+        \\    };
+        \\    return error.TestExpectedError;
+        \\}
+        \\
+        \\fn alwaysErrors() !void { return error.ThisErrorShouldNotAppearInAnyTrace; }
+        \\fn foo() !void { return error.Foo; }
+        \\
+        \\pub fn main() !void {
+        \\    try expectError(error.ThisErrorShouldNotAppearInAnyTrace, alwaysErrors());
+        \\    try expectError(error.ThisErrorShouldNotAppearInAnyTrace, alwaysErrors());
+        \\    try expectError(error.Foo, foo());
+        \\
+        \\    // Only the error trace for this failing check should appear:
+        \\    try expectError(error.Bar, foo());
+        \\}
+        ,
+        .Debug = .{
+            .expect = 
+            \\error: TestExpectedError
+            \\source.zig:9:18: [address] in foo (test)
+            \\fn foo() !void { return error.Foo; }
+            \\                 ^
+            \\source.zig:5:5: [address] in expectError (test)
+            \\    return error.TestExpectedError;
+            \\    ^
+            \\source.zig:17:5: [address] in main (test)
+            \\    try expectError(error.Bar, foo());
+            \\    ^
+            \\
+            ,
+        },
+        .ReleaseSafe = .{
+            .exclude_os = .{
+                .windows, // TODO
+            },
+            .expect = 
+            \\error: TestExpectedError
+            \\source.zig:9:18: [address] in [function]
+            \\fn foo() !void { return error.Foo; }
+            \\                 ^
+            \\source.zig:5:5: [address] in [function]
+            \\    return error.TestExpectedError;
+            \\    ^
+            \\source.zig:17:5: [address] in [function]
+            \\    try expectError(error.Bar, foo());
+            \\    ^
+            \\
+            ,
+        },
+        .ReleaseFast = .{
+            .expect = 
+            \\error: TestExpectedError
+            \\
+            ,
+        },
+        .ReleaseSmall = .{
+            .expect = 
+            \\error: TestExpectedError
+            \\
+            ,
+        },
+    });
+
     cases.addCase(.{
         .name = "try return from within catch",
         .source =