Commit 3924f173af

mlugg <mlugg@mlugg.co.uk>
2025-02-01 09:07:14
compiler: do not propagate result type to `try` operand
This commit effectively reverts 9e683f0, and hence un-accepts #19777. While nice in theory, this proposal turned out to have a few problems. Firstly, supplying a result type implicitly coerces the operand to this type -- that's the main point of result types! But for `try`, this is actually a bad idea; we want a redundant `try` to be a compile error, not to silently coerce the non-error value to an error union. In practice, this didn't always happen, because the implementation was buggy anyway; but when it did, it was really quite silly. For instance, `try try ... try .{ ... }` was an accepted expression, with the inner initializer being initially coerced to `E!E!...E!T`. Secondly, the result type inference here didn't play nicely with `return`. If you write `return try`, the operand would actually receive a result type of `E!E!T`, since the `return` gave a result type of `E!T` and the `try` wrapped it in *another* error union. More generally, the problem here is that `try` doesn't know when it should or shouldn't nest error unions. This occasionally broke code which looked like it should work. So, this commit prevents `try` from propagating result types through to its operand. A key motivation for the original proposal here was decl literals; so, as a special case, `try .foo(...)` is still an allowed syntax form, caught by AstGen and specially lowered. This does open the doors to allowing other special cases for decl literals in future, such as `.foo(...) catch ...`, but those proposals are for another time. Resolves: #21991 Resolves: #22633
1 parent c225b78
Changed files (6)
lib
src
test
behavior
cases
compile_errors
lib/std/zig/AstGen.zig
@@ -851,7 +851,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
         .async_call_comma,
         => {
             var buf: [1]Ast.Node.Index = undefined;
-            return callExpr(gz, scope, ri, node, tree.fullCall(&buf, node).?);
+            return callExpr(gz, scope, ri, .none, node, tree.fullCall(&buf, node).?);
         },
 
         .unreachable_literal => {
@@ -3009,8 +3009,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .validate_ptr_array_init,
             .validate_ref_ty,
             .validate_const,
-            .try_operand_ty,
-            .try_ref_operand_ty,
             => break :b true,
 
             .@"defer" => unreachable,
@@ -6158,20 +6156,24 @@ fn tryExpr(
     const try_lc: LineColumn = .{ astgen.source_line - parent_gz.decl_line, astgen.source_column };
 
     const operand_rl: ResultInfo.Loc, const block_tag: Zir.Inst.Tag = switch (ri.rl) {
-        .ref => .{ .ref, .try_ptr },
-        .ref_coerced_ty => |payload_ptr_ty| .{
-            .{ .ref_coerced_ty = try parent_gz.addUnNode(.try_ref_operand_ty, payload_ptr_ty, node) },
-            .try_ptr,
-        },
-        else => if (try ri.rl.resultType(parent_gz, node)) |payload_ty| .{
-            // `coerced_ty` is OK due to the `rvalue` call below
-            .{ .coerced_ty = try parent_gz.addUnNode(.try_operand_ty, payload_ty, node) },
-            .@"try",
-        } else .{ .none, .@"try" },
+        .ref, .ref_coerced_ty => .{ .ref, .try_ptr },
+        else => .{ .none, .@"try" },
     };
     const operand_ri: ResultInfo = .{ .rl = operand_rl, .ctx = .error_handling_expr };
-    // This could be a pointer or value depending on the `ri` parameter.
-    const operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node);
+    const operand = operand: {
+        // As a special case, we need to detect this form:
+        // `try .foo(...)`
+        // This is a decl literal form, even though we don't propagate a result type through `try`.
+        var buf: [1]Ast.Node.Index = undefined;
+        if (astgen.tree.fullCall(&buf, operand_node)) |full_call| {
+            const res_ty: Zir.Inst.Ref = try ri.rl.resultType(parent_gz, operand_node) orelse .none;
+            break :operand try callExpr(parent_gz, scope, operand_ri, res_ty, operand_node, full_call);
+        }
+
+        // This could be a pointer or value depending on the `ri` parameter.
+        break :operand try reachableExpr(parent_gz, scope, operand_ri, operand_node, node);
+    };
+
     const try_inst = try parent_gz.makeBlockInst(block_tag, node);
     try parent_gz.instructions.append(astgen.gpa, try_inst);
 
@@ -10236,12 +10238,15 @@ fn callExpr(
     gz: *GenZir,
     scope: *Scope,
     ri: ResultInfo,
+    /// If this is not `.none` and this call is a decl literal form (`.foo(...)`), then this
+    /// type is used as the decl literal result type instead of the result type from `ri.rl`.
+    override_decl_literal_type: Zir.Inst.Ref,
     node: Ast.Node.Index,
     call: Ast.full.Call,
 ) InnerError!Zir.Inst.Ref {
     const astgen = gz.astgen;
 
-    const callee = try calleeExpr(gz, scope, ri.rl, call.ast.fn_expr);
+    const callee = try calleeExpr(gz, scope, ri.rl, override_decl_literal_type, call.ast.fn_expr);
     const modifier: std.builtin.CallModifier = blk: {
         if (call.async_token != null) {
             break :blk .async_kw;
@@ -10367,6 +10372,9 @@ fn calleeExpr(
     gz: *GenZir,
     scope: *Scope,
     call_rl: ResultInfo.Loc,
+    /// If this is not `.none` and this call is a decl literal form (`.foo(...)`), then this
+    /// type is used as the decl literal result type instead of the result type from `call_rl`.
+    override_decl_literal_type: Zir.Inst.Ref,
     node: Ast.Node.Index,
 ) InnerError!Callee {
     const astgen = gz.astgen;
@@ -10393,7 +10401,14 @@ fn calleeExpr(
                 .field_name_start = str_index,
             } };
         },
-        .enum_literal => if (try call_rl.resultType(gz, node)) |res_ty| {
+        .enum_literal => {
+            const res_ty = res_ty: {
+                if (override_decl_literal_type != .none) break :res_ty override_decl_literal_type;
+                break :res_ty try call_rl.resultType(gz, node) orelse {
+                    // No result type; lower to a literal call of an enum literal.
+                    return .{ .direct = try expr(gz, scope, .{ .rl = .none }, node) };
+                };
+            };
             // Decl literal call syntax, e.g.
             // `const foo: T = .init();`
             // Look up `init` in `T`, but don't try and coerce it.
@@ -10403,8 +10418,6 @@ fn calleeExpr(
                 .field_name_start = str_index,
             });
             return .{ .direct = callee };
-        } else {
-            return .{ .direct = try expr(gz, scope, .{ .rl = .none }, node) };
         },
         else => return .{ .direct = try expr(gz, scope, .{ .rl = .none }, node) },
     }
lib/std/zig/Zir.zig
@@ -721,14 +721,6 @@ pub const Inst = struct {
         /// Result is always void.
         /// Uses the `un_node` union field. Node is the initializer. Operand is the initializer value.
         validate_const,
-        /// Given a type `T`, construct the type `E!T`, where `E` is this function's error set, to be used
-        /// as the result type of a `try` operand. Generic poison is propagated.
-        /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `T`.
-        try_operand_ty,
-        /// Given a type `*T`, construct the type `*E!T`, where `E` is this function's error set, to be used
-        /// as the result type of a `try` operand whose address is taken with `&`. Generic poison is propagated.
-        /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `*T`.
-        try_ref_operand_ty,
 
         // The following tags all relate to struct initialization expressions.
 
@@ -1304,8 +1296,6 @@ pub const Inst = struct {
                 .array_init_elem_ptr,
                 .validate_ref_ty,
                 .validate_const,
-                .try_operand_ty,
-                .try_ref_operand_ty,
                 .restore_err_ret_index_unconditional,
                 .restore_err_ret_index_fn_entry,
                 => false,
@@ -1365,8 +1355,6 @@ pub const Inst = struct {
                 .validate_ptr_array_init,
                 .validate_ref_ty,
                 .validate_const,
-                .try_operand_ty,
-                .try_ref_operand_ty,
                 => true,
 
                 .param,
@@ -1749,8 +1737,6 @@ pub const Inst = struct {
                 .coerce_ptr_elem_ty = .pl_node,
                 .validate_ref_ty = .un_tok,
                 .validate_const = .un_node,
-                .try_operand_ty = .un_node,
-                .try_ref_operand_ty = .un_node,
 
                 .int_from_ptr = .un_node,
                 .compile_error = .un_node,
@@ -4196,8 +4182,6 @@ fn findTrackableInner(
         .coerce_ptr_elem_ty,
         .validate_ref_ty,
         .validate_const,
-        .try_operand_ty,
-        .try_ref_operand_ty,
         .struct_init_empty,
         .struct_init_empty_result,
         .struct_init_empty_ref_result,
src/print_zir.zig
@@ -278,8 +278,6 @@ const Writer = struct {
             .opt_eu_base_ptr_init,
             .restore_err_ret_index_unconditional,
             .restore_err_ret_index_fn_entry,
-            .try_operand_ty,
-            .try_ref_operand_ty,
             => try self.writeUnNode(stream, inst),
 
             .ref,
src/Sema.zig
@@ -1257,8 +1257,6 @@ fn analyzeBodyInner(
             .validate_array_init_ref_ty   => try sema.zirValidateArrayInitRefTy(block, inst),
             .opt_eu_base_ptr_init         => try sema.zirOptEuBasePtrInit(block, inst),
             .coerce_ptr_elem_ty           => try sema.zirCoercePtrElemTy(block, inst),
-            .try_operand_ty               => try sema.zirTryOperandTy(block, inst, false),
-            .try_ref_operand_ty           => try sema.zirTryOperandTy(block, inst, true),
 
             .clz       => try sema.zirBitCount(block, inst, .clz,      Value.clz),
             .ctz       => try sema.zirBitCount(block, inst, .ctz,      Value.ctz),
@@ -2085,20 +2083,6 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi
                 const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
                 cur = un_node.operand;
             },
-            .try_operand_ty => {
-                // Either the input type was itself poison, or it was a slice, which we cannot translate
-                // to an overall result type.
-                const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
-                const operand_ref = try sema.resolveInst(un_node.operand);
-                if (operand_ref == .generic_poison_type) {
-                    // The input was poison -- keep looking.
-                    cur = un_node.operand;
-                    continue;
-                }
-                // We got a poison because the result type was a slice. This is a tricky case -- let's just
-                // not bother explaining it to the user for now...
-                return .unknown;
-            },
             .struct_init_field_type => {
                 const pl_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
                 const extra = sema.code.extraData(Zir.Inst.FieldType, pl_node.payload_index).data;
test/behavior/try.zig
@@ -68,25 +68,6 @@ test "`try`ing an if/else expression" {
     try std.testing.expectError(error.Test, S.getError2());
 }
 
-test "try forwards result location" {
-    if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
-
-    const S = struct {
-        fn foo(err: bool) error{Foo}!u32 {
-            const result: error{ Foo, Bar }!u32 = if (err) error.Foo else 123;
-            const res_int: u32 = try @errorCast(result);
-            return res_int;
-        }
-    };
-
-    try expect((S.foo(false) catch return error.TestUnexpectedResult) == 123);
-    try std.testing.expectError(error.Foo, S.foo(true));
-}
-
 test "'return try' of empty error set in function returning non-error" {
     if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
@@ -123,3 +104,24 @@ test "'return try' of empty error set in function returning non-error" {
     try S.doTheTest();
     try comptime S.doTheTest();
 }
+
+test "'return try' through conditional" {
+    const S = struct {
+        fn get(t: bool) !u32 {
+            return try if (t) inner() else error.TestFailed;
+        }
+        fn inner() !u16 {
+            return 123;
+        }
+    };
+
+    {
+        const result = try S.get(true);
+        try expect(result == 123);
+    }
+
+    {
+        const result = try comptime S.get(true);
+        comptime std.debug.assert(result == 123);
+    }
+}
test/cases/compile_errors/redundant_try.zig
@@ -0,0 +1,52 @@
+const S = struct { x: u32 = 0 };
+const T = struct { []const u8 };
+
+fn test0() !void {
+    const x: u8 = try 1;
+    _ = x;
+}
+
+fn test1() !void {
+    const x: S = try .{};
+    _ = x;
+}
+
+fn test2() !void {
+    const x: S = try .{ .x = 123 };
+    _ = x;
+}
+
+fn test3() !void {
+    const x: S = try try .{ .x = 123 };
+    _ = x;
+}
+
+fn test4() !void {
+    const x: T = try .{"hello"};
+    _ = x;
+}
+
+fn test5() !void {
+    const x: error{Foo}!u32 = 123;
+    _ = try try x;
+}
+
+comptime {
+    _ = &test0;
+    _ = &test1;
+    _ = &test2;
+    _ = &test3;
+    _ = &test4;
+    _ = &test5;
+}
+
+// error
+//
+// :5:23: error: expected error union type, found 'comptime_int'
+// :10:23: error: expected error union type, found '@TypeOf(.{})'
+// :15:23: error: expected error union type, found 'tmp.test2__struct_493'
+// :15:23: note: struct declared here
+// :20:27: error: expected error union type, found 'tmp.test3__struct_495'
+// :20:27: note: struct declared here
+// :25:23: error: expected error union type, found 'struct { comptime *const [5:0]u8 = "hello" }'
+// :31:13: error: expected error union type, found 'u32'