Commit 8f2e82dbf6

Veikka Tuominen <git@vexu.eu>
2022-09-10 22:09:09
safety: show error return trace when unwrapping error in switch
1 parent 62ecc15
src/AstGen.zig
@@ -884,33 +884,6 @@ fn expr(gz: *GenZir, scope: *Scope, rl: ResultLoc, node: Ast.Node.Index) InnerEr
                 catch_token + 2
             else
                 null;
-
-            var rhs = node_datas[node].rhs;
-            while (true) switch (node_tags[rhs]) {
-                .grouped_expression => rhs = node_datas[rhs].lhs,
-                .unreachable_literal => {
-                    if (payload_token != null and mem.eql(u8, tree.tokenSlice(payload_token.?), "_")) {
-                        return astgen.failTok(payload_token.?, "discard of error capture; omit it instead", .{});
-                    } else if (payload_token != null) {
-                        return astgen.failTok(payload_token.?, "unused capture", .{});
-                    }
-                    const lhs = node_datas[node].lhs;
-
-                    const operand = try reachableExpr(gz, scope, switch (rl) {
-                        .ref => .ref,
-                        else => .none,
-                    }, lhs, lhs);
-                    const result = try gz.addUnNode(switch (rl) {
-                        .ref => .err_union_payload_safe_ptr,
-                        else => .err_union_payload_safe,
-                    }, operand, node);
-                    switch (rl) {
-                        .none, .coerced_ty, .discard, .ref => return result,
-                        else => return rvalue(gz, rl, result, lhs),
-                    }
-                },
-                else => break,
-            };
             switch (rl) {
                 .ref => return orelseCatchExpr(
                     gz,
@@ -2375,9 +2348,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .optional_payload_unsafe,
             .optional_payload_safe_ptr,
             .optional_payload_unsafe_ptr,
-            .err_union_payload_safe,
             .err_union_payload_unsafe,
-            .err_union_payload_safe_ptr,
             .err_union_payload_unsafe_ptr,
             .err_union_code,
             .err_union_code_ptr,
src/print_zir.zig
@@ -170,9 +170,7 @@ const Writer = struct {
             .optional_payload_unsafe,
             .optional_payload_safe_ptr,
             .optional_payload_unsafe_ptr,
-            .err_union_payload_safe,
             .err_union_payload_unsafe,
-            .err_union_payload_safe_ptr,
             .err_union_payload_unsafe_ptr,
             .err_union_code,
             .err_union_code_ptr,
src/Sema.zig
@@ -747,10 +747,8 @@ fn analyzeBodyInner(
             .int_to_enum                  => try sema.zirIntToEnum(block, inst),
             .err_union_code               => try sema.zirErrUnionCode(block, inst),
             .err_union_code_ptr           => try sema.zirErrUnionCodePtr(block, inst),
-            .err_union_payload_safe       => try sema.zirErrUnionPayload(block, inst, true),
-            .err_union_payload_safe_ptr   => try sema.zirErrUnionPayloadPtr(block, inst, true),
-            .err_union_payload_unsafe     => try sema.zirErrUnionPayload(block, inst, false),
-            .err_union_payload_unsafe_ptr => try sema.zirErrUnionPayloadPtr(block, inst, false),
+            .err_union_payload_unsafe     => try sema.zirErrUnionPayload(block, inst),
+            .err_union_payload_unsafe_ptr => try sema.zirErrUnionPayloadPtr(block, inst),
             .error_union_type             => try sema.zirErrorUnionType(block, inst),
             .error_value                  => try sema.zirErrorValue(block, inst),
             .field_ptr                    => try sema.zirFieldPtr(block, inst, false),
@@ -1355,6 +1353,8 @@ fn analyzeBodyInner(
                 const else_body = sema.code.extra[extra.end + then_body.len ..][0..extra.data.else_body_len];
                 const cond = try sema.resolveInstConst(block, cond_src, extra.data.condition, "condition in comptime branch must be comptime known");
                 const inline_body = if (cond.val.toBool()) then_body else else_body;
+
+                try sema.maybeErrorUnwrapCondbr(block, inline_body, extra.data.condition, cond_src);
                 const break_data = (try sema.analyzeBodyBreak(block, inline_body)) orelse
                     break always_noreturn;
                 if (inst == break_data.block_inst) {
@@ -7426,7 +7426,6 @@ fn zirErrUnionPayload(
     sema: *Sema,
     block: *Block,
     inst: Zir.Inst.Index,
-    safety_check: bool,
 ) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
@@ -7441,7 +7440,7 @@ fn zirErrUnionPayload(
             err_union_ty.fmt(sema.mod),
         });
     }
-    return sema.analyzeErrUnionPayload(block, src, err_union_ty, operand, operand_src, safety_check);
+    return sema.analyzeErrUnionPayload(block, src, err_union_ty, operand, operand_src, false);
 }
 
 fn analyzeErrUnionPayload(
@@ -7479,7 +7478,6 @@ fn zirErrUnionPayloadPtr(
     sema: *Sema,
     block: *Block,
     inst: Zir.Inst.Index,
-    safety_check: bool,
 ) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
@@ -7488,7 +7486,7 @@ fn zirErrUnionPayloadPtr(
     const operand = try sema.resolveInst(inst_data.operand);
     const src = inst_data.src();
 
-    return sema.analyzeErrUnionPayloadPtr(block, src, operand, safety_check, false);
+    return sema.analyzeErrUnionPayloadPtr(block, src, operand, false, false);
 }
 
 fn analyzeErrUnionPayloadPtr(
@@ -9247,6 +9245,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     var empty_enum = false;
 
     const operand_ty = sema.typeOf(operand);
+    const err_set = operand_ty.zigTypeTag() == .ErrorSet;
 
     var else_error_ty: ?Type = null;
 
@@ -9829,6 +9828,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 // Validation above ensured these will succeed.
                 const item_val = sema.resolveConstValue(&child_block, .unneeded, item, undefined) catch unreachable;
                 if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                    if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
                     return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
                 }
             }
@@ -9851,6 +9851,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     // Validation above ensured these will succeed.
                     const item_val = sema.resolveConstValue(&child_block, .unneeded, item, undefined) catch unreachable;
                     if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
                         return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
                     }
                 }
@@ -9868,6 +9869,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if ((try sema.compare(block, src, operand_val, .gte, first_tv.val, operand_ty)) and
                         (try sema.compare(block, src, operand_val, .lte, last_tv.val, operand_ty)))
                     {
+                        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
                         return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
                     }
                 }
@@ -9875,6 +9877,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 extra_index += body_len;
             }
         }
+        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand);
         return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
     }
 
@@ -9885,6 +9888,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (special_prong == .none) {
             return sema.fail(block, src, "switch must handle all possibilities", .{});
         }
+        if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand)) {
+            return Air.Inst.Ref.unreachable_value;
+        }
         return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
     }
 
@@ -9927,7 +9933,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             break :blk field_ty.zigTypeTag() != .NoReturn;
         } else true;
 
-        if (analyze_body) {
+        if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
+            // nothing to do here
+        } else if (analyze_body) {
             _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
                 error.ComptimeBreak => {
                     const zir_datas = sema.code.instructions.items(.data);
@@ -9995,7 +10003,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
             const body = sema.code.extra[extra_index..][0..body_len];
             extra_index += body_len;
-            if (analyze_body) {
+            if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
+                // nothing to do here
+            } else if (analyze_body) {
                 _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
                     error.ComptimeBreak => {
                         const zir_datas = sema.code.instructions.items(.data);
@@ -10085,18 +10095,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
             const body = sema.code.extra[extra_index..][0..body_len];
             extra_index += body_len;
-            _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
-                error.ComptimeBreak => {
-                    const zir_datas = sema.code.instructions.items(.data);
-                    const break_data = zir_datas[sema.comptime_break_inst].@"break";
-                    try sema.addRuntimeBreak(&case_block, .{
-                        .block_inst = break_data.block_inst,
-                        .operand = break_data.operand,
-                        .inst = sema.comptime_break_inst,
-                    });
-                },
-                else => |e| return e,
-            };
+            if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
+                // nothing to do here
+            } else {
+                _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) {
+                    error.ComptimeBreak => {
+                        const zir_datas = sema.code.instructions.items(.data);
+                        const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                        try sema.addRuntimeBreak(&case_block, .{
+                            .block_inst = break_data.block_inst,
+                            .operand = break_data.operand,
+                            .inst = sema.comptime_break_inst,
+                        });
+                    },
+                    else => |e| return e,
+                };
+            }
 
             try wip_captures.finalize();
 
@@ -10141,8 +10155,11 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             } else false
         else
             true;
-
-        if (special.body.len != 0 and analyze_body) {
+        if (special.body.len != 0 and err_set and
+            try sema.maybeErrorUnwrap(&case_block, special.body, operand))
+        {
+            // nothing to do here
+        } else if (special.body.len != 0 and analyze_body) {
             _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
                 error.ComptimeBreak => {
                     const zir_datas = sema.code.instructions.items(.data);
@@ -10400,6 +10417,109 @@ fn validateSwitchNoRange(
     return sema.failWithOwnedErrorMsg(msg);
 }
 
+fn maybeErrorUnwrap(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, operand: Air.Inst.Ref) !bool {
+    const this_feature_is_implemented_in_the_backend =
+        sema.mod.comp.bin_file.options.use_llvm;
+
+    if (!this_feature_is_implemented_in_the_backend) return false;
+
+    const tags = sema.code.instructions.items(.tag);
+    for (body) |inst| {
+        switch (tags[inst]) {
+            .dbg_block_begin,
+            .dbg_block_end,
+            .dbg_stmt,
+            .@"unreachable",
+            .str,
+            .as_node,
+            .panic,
+            .field_val,
+            => {},
+            else => return false,
+        }
+    }
+
+    for (body) |inst| {
+        const air_inst = switch (tags[inst]) {
+            .dbg_block_begin,
+            .dbg_block_end,
+            => continue,
+            .dbg_stmt => {
+                try sema.zirDbgStmt(block, inst);
+                continue;
+            },
+            .str => try sema.zirStr(block, inst),
+            .as_node => try sema.zirAsNode(block, inst),
+            .field_val => try sema.zirFieldVal(block, inst),
+            .@"unreachable" => {
+                const inst_data = sema.code.instructions.items(.data)[inst].@"unreachable";
+                const src = inst_data.src();
+
+                const panic_fn = try sema.getBuiltin(block, src, "panicUnwrapError");
+                const err_return_trace = try sema.getErrorReturnTrace(block, src);
+                const args: [2]Air.Inst.Ref = .{ err_return_trace, operand };
+                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null);
+                return true;
+            },
+            .panic => {
+                const inst_data = sema.code.instructions.items(.data)[inst].un_node;
+                const src = inst_data.src();
+                const msg_inst = try sema.resolveInst(inst_data.operand);
+
+                const panic_fn = try sema.getBuiltin(block, src, "panic");
+                const err_return_trace = try sema.getErrorReturnTrace(block, src);
+                const args: [2]Air.Inst.Ref = .{ msg_inst, err_return_trace };
+                _ = try sema.analyzeCall(block, panic_fn, src, src, .auto, false, &args, null);
+                return true;
+            },
+            else => unreachable,
+        };
+        if (sema.typeOf(air_inst).isNoReturn())
+            return true;
+        try sema.inst_map.put(sema.gpa, inst, air_inst);
+    }
+    unreachable;
+}
+
+fn maybeErrorUnwrapCondbr(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, cond: Zir.Inst.Ref, cond_src: LazySrcLoc) !void {
+    const index = Zir.refToIndex(cond) orelse return;
+    if (sema.code.instructions.items(.tag)[index] != .is_non_err) return;
+
+    const err_inst_data = sema.code.instructions.items(.data)[index].un_node;
+    const err_operand = try sema.resolveInst(err_inst_data.operand);
+    const operand_ty = sema.typeOf(err_operand);
+    if (operand_ty.zigTypeTag() == .ErrorSet) {
+        try sema.maybeErrorUnwrapComptime(block, body, err_operand);
+        return;
+    }
+    if (try sema.resolveDefinedValue(block, cond_src, err_operand)) |val| {
+        if (val.getError() == null) return;
+        try sema.maybeErrorUnwrapComptime(block, body, err_operand);
+    }
+}
+
+fn maybeErrorUnwrapComptime(sema: *Sema, block: *Block, body: []const Zir.Inst.Index, operand: Air.Inst.Ref) !void {
+    const tags = sema.code.instructions.items(.tag);
+    const inst = for (body) |inst| {
+        switch (tags[inst]) {
+            .dbg_block_begin,
+            .dbg_block_end,
+            .dbg_stmt,
+            => {},
+            .@"unreachable" => break inst,
+            else => return,
+        }
+    } else return;
+    const inst_data = sema.code.instructions.items(.data)[inst].@"unreachable";
+    const src = inst_data.src();
+
+    if (try sema.resolveDefinedValue(block, src, operand)) |val| {
+        if (val.getError()) |name| {
+            return sema.fail(block, src, "caught unexpected error '{s}'", .{name});
+        }
+    }
+}
+
 fn zirHasField(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
@@ -15152,6 +15272,8 @@ fn zirCondbr(
 
     if (try sema.resolveDefinedValue(parent_block, cond_src, cond)) |cond_val| {
         const body = if (cond_val.toBool()) then_body else else_body;
+
+        try sema.maybeErrorUnwrapCondbr(parent_block, body, extra.data.condition, cond_src);
         // We use `analyzeBodyInner` since we want to propagate any possible
         // `error.ComptimeBreak` to the caller.
         return sema.analyzeBodyInner(parent_block, body);
@@ -15182,18 +15304,34 @@ fn zirCondbr(
     const true_instructions = sub_block.instructions.toOwnedSlice(gpa);
     defer gpa.free(true_instructions);
 
-    _ = sema.analyzeBodyInner(&sub_block, else_body) catch |err| switch (err) {
-        error.ComptimeBreak => {
-            const zir_datas = sema.code.instructions.items(.data);
-            const break_data = zir_datas[sema.comptime_break_inst].@"break";
-            try sema.addRuntimeBreak(&sub_block, .{
-                .block_inst = break_data.block_inst,
-                .operand = break_data.operand,
-                .inst = sema.comptime_break_inst,
-            });
-        },
-        else => |e| return e,
+    const err_cond = blk: {
+        const index = Zir.refToIndex(extra.data.condition) orelse break :blk null;
+        if (sema.code.instructions.items(.tag)[index] != .is_non_err) break :blk null;
+
+        const err_inst_data = sema.code.instructions.items(.data)[index].un_node;
+        const err_operand = try sema.resolveInst(err_inst_data.operand);
+        const operand_ty = sema.typeOf(err_operand);
+        assert(operand_ty.zigTypeTag() == .ErrorUnion);
+        const result_ty = operand_ty.errorUnionSet();
+        break :blk try sub_block.addTyOp(.unwrap_errunion_err, result_ty, err_operand);
     };
+
+    if (err_cond != null and try sema.maybeErrorUnwrap(&sub_block, else_body, err_cond.?)) {
+        // nothing to do
+    } else {
+        _ = sema.analyzeBodyInner(&sub_block, else_body) catch |err| switch (err) {
+            error.ComptimeBreak => {
+                const zir_datas = sema.code.instructions.items(.data);
+                const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                try sema.addRuntimeBreak(&sub_block, .{
+                    .block_inst = break_data.block_inst,
+                    .operand = break_data.operand,
+                    .inst = sema.comptime_break_inst,
+                });
+            },
+            else => |e| return e,
+        };
+    }
     try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.CondBr).Struct.fields.len +
         true_instructions.len + sub_block.instructions.items.len);
     _ = try parent_block.addInst(.{
src/Zir.zig
@@ -629,20 +629,10 @@ pub const Inst = struct {
         /// No safety checks.
         /// Uses the `un_node` field.
         optional_payload_unsafe_ptr,
-        /// E!T => T with safety.
-        /// Given an error union value, returns the payload value, with a safety check
-        /// that the value is not an error. Used for catch, if, and while.
-        /// Uses the `un_node` field.
-        err_union_payload_safe,
         /// E!T => T without safety.
         /// Given an error union value, returns the payload value. No safety checks.
         /// Uses the `un_node` field.
         err_union_payload_unsafe,
-        /// *E!T => *T with safety.
-        /// Given a pointer to an error union value, returns a pointer to the payload value,
-        /// with a safety check that the value is not an error. Used for catch, if, and while.
-        /// Uses the `un_node` field.
-        err_union_payload_safe_ptr,
         /// *E!T => *T without safety.
         /// Given a pointer to a error union value, returns a pointer to the payload value.
         /// No safety checks.
@@ -1120,9 +1110,7 @@ pub const Inst = struct {
                 .optional_payload_unsafe,
                 .optional_payload_safe_ptr,
                 .optional_payload_unsafe_ptr,
-                .err_union_payload_safe,
                 .err_union_payload_unsafe,
-                .err_union_payload_safe_ptr,
                 .err_union_payload_unsafe_ptr,
                 .err_union_code,
                 .err_union_code_ptr,
@@ -1421,9 +1409,7 @@ pub const Inst = struct {
                 .optional_payload_unsafe,
                 .optional_payload_safe_ptr,
                 .optional_payload_unsafe_ptr,
-                .err_union_payload_safe,
                 .err_union_payload_unsafe,
-                .err_union_payload_safe_ptr,
                 .err_union_payload_unsafe_ptr,
                 .err_union_code,
                 .err_union_code_ptr,
@@ -1692,9 +1678,7 @@ pub const Inst = struct {
                 .optional_payload_unsafe = .un_node,
                 .optional_payload_safe_ptr = .un_node,
                 .optional_payload_unsafe_ptr = .un_node,
-                .err_union_payload_safe = .un_node,
                 .err_union_payload_unsafe = .un_node,
-                .err_union_payload_safe_ptr = .un_node,
                 .err_union_payload_unsafe_ptr = .un_node,
                 .err_union_code = .un_node,
                 .err_union_code_ptr = .un_node,
test/cases/safety/unwrap error switch.zig
@@ -0,0 +1,21 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "attempt to unwrap error: Whatever")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+pub fn main() !void {
+    bar() catch |err| switch (err) {
+        error.Whatever => unreachable,
+    };
+    return error.TestFailed;
+}
+fn bar() !void {
+    return error.Whatever;
+}
+// run
+// backend=llvm
+// target=native