Commit 9e683f0f35

mlugg <mlugg@mlugg.co.uk>
2024-08-31 02:25:23
compiler: provide result type to operand of `try`
This is mainly useful in conjunction with Decl Literals (#9938). Resolves: #19777
1 parent fbac7af
Changed files (5)
lib
src
test
behavior
lib/std/zig/AstGen.zig
@@ -2914,6 +2914,8 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .validate_array_init_result_ty,
             .validate_ptr_array_init,
             .validate_ref_ty,
+            .try_operand_ty,
+            .try_ref_operand_ty,
             => break :b true,
 
             .@"defer" => unreachable,
@@ -5887,9 +5889,18 @@ fn tryExpr(
     }
     const try_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column };
 
-    const operand_ri: ResultInfo = switch (ri.rl) {
-        .ref, .ref_coerced_ty => .{ .rl = .ref, .ctx = .error_handling_expr },
-        else => .{ .rl = .none, .ctx = .error_handling_expr },
+    const operand_ri: ResultInfo = .{
+        .rl = switch (ri.rl) {
+            .ref => .ref,
+            .ref_coerced_ty => |payload_ptr_ty| .{
+                .ref_coerced_ty = try parent_gz.addUnNode(.try_ref_operand_ty, payload_ptr_ty, node),
+            },
+            else => if (try ri.rl.resultType(parent_gz, node)) |payload_ty| .{
+                // `coerced_ty` is OK due to the `rvalue` call below
+                .coerced_ty = try parent_gz.addUnNode(.try_operand_ty, payload_ty, node),
+            } else .none,
+        },
+        .ctx = .error_handling_expr,
     };
     // This could be a pointer or value depending on the `ri` parameter.
     const operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node);
lib/std/zig/Zir.zig
@@ -684,6 +684,14 @@ pub const Inst = struct {
         /// operator. Emit a compile error if not.
         /// Uses the `un_tok` union field. Token is the `&` operator. Operand is the type.
         validate_ref_ty,
+        /// Given a type `T`, construct the type `E!T`, where `E` is this function's error set, to be used
+        /// as the result type of a `try` operand. Generic poison is propagated.
+        /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `T`.
+        try_operand_ty,
+        /// Given a type `*T`, construct the type `*E!T`, where `E` is this function's error set, to be used
+        /// as the result type of a `try` operand whose address is taken with `&`. Generic poison is propagated.
+        /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `*T`.
+        try_ref_operand_ty,
 
         // The following tags all relate to struct initialization expressions.
 
@@ -1254,6 +1262,8 @@ pub const Inst = struct {
                 .array_init_elem_type,
                 .array_init_elem_ptr,
                 .validate_ref_ty,
+                .try_operand_ty,
+                .try_ref_operand_ty,
                 .restore_err_ret_index_unconditional,
                 .restore_err_ret_index_fn_entry,
                 => false,
@@ -1324,6 +1334,8 @@ pub const Inst = struct {
                 .validate_array_init_result_ty,
                 .validate_ptr_array_init,
                 .validate_ref_ty,
+                .try_operand_ty,
+                .try_ref_operand_ty,
                 => true,
 
                 .param,
@@ -1698,6 +1710,8 @@ pub const Inst = struct {
                 .opt_eu_base_ptr_init = .un_node,
                 .coerce_ptr_elem_ty = .pl_node,
                 .validate_ref_ty = .un_tok,
+                .try_operand_ty = .un_node,
+                .try_ref_operand_ty = .un_node,
 
                 .int_from_ptr = .un_node,
                 .compile_error = .un_node,
@@ -3834,6 +3848,8 @@ fn findDeclsInner(
         .opt_eu_base_ptr_init,
         .coerce_ptr_elem_ty,
         .validate_ref_ty,
+        .try_operand_ty,
+        .try_ref_operand_ty,
         .struct_init_empty,
         .struct_init_empty_result,
         .struct_init_empty_ref_result,
src/print_zir.zig
@@ -277,6 +277,8 @@ const Writer = struct {
             .opt_eu_base_ptr_init,
             .restore_err_ret_index_unconditional,
             .restore_err_ret_index_fn_entry,
+            .try_operand_ty,
+            .try_ref_operand_ty,
             => try self.writeUnNode(stream, inst),
 
             .ref,
src/Sema.zig
@@ -1177,6 +1177,8 @@ fn analyzeBodyInner(
             .validate_array_init_ref_ty   => try sema.zirValidateArrayInitRefTy(block, inst),
             .opt_eu_base_ptr_init         => try sema.zirOptEuBasePtrInit(block, inst),
             .coerce_ptr_elem_ty           => try sema.zirCoercePtrElemTy(block, inst),
+            .try_operand_ty               => try sema.zirTryOperandTy(block, inst, false),
+            .try_ref_operand_ty           => try sema.zirTryOperandTy(block, inst, true),
 
             .clz       => try sema.zirBitCount(block, inst, .clz,      Value.clz),
             .ctz       => try sema.zirBitCount(block, inst, .ctz,      Value.ctz),
@@ -2024,6 +2026,22 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi
                 const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
                 cur = un_node.operand;
             },
+            .try_operand_ty => {
+                // Either the input type was itself poison, or it was a slice, which we cannot translate
+                // to an overall result type.
+                const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
+                const operand_ref = sema.resolveInst(un_node.operand) catch |err| switch (err) {
+                    error.GenericPoison => unreachable, // this is a type, not a value
+                };
+                if (operand_ref == .generic_poison_type) {
+                    // The input was poison -- keep looking.
+                    cur = un_node.operand;
+                    continue;
+                }
+                // We got a poison because the result type was a slice. This is a tricky case -- let's just
+                // not bother explaining it to the user for now...
+                return .unknown;
+            },
             .struct_init_field_type => {
                 const pl_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
                 const extra = sema.code.extraData(Zir.Inst.FieldType, pl_node.payload_index).data;
@@ -4423,6 +4441,59 @@ fn zirCoercePtrElemTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
     }
 }
 
+/// Handles the `try_operand_ty` (`is_ref == false`) and `try_ref_operand_ty` (`is_ref == true`)
+/// ZIR instructions: computes the result type to give to the operand of a `try` expression.
+///
+/// The instruction's operand is a type. When `is_ref` is false it is used directly as the
+/// payload type `T` and the returned type is `E!T`. When `is_ref` is true it must be a
+/// single-item pointer `*T`, and the returned type is a pointer with the same pointer info
+/// whose child is `E!T`. `E` is taken from `sema.fn_ret_ty` after unwrapping any optionals.
+///
+/// Returns `.generic_poison_type` when the operand type resolves to generic poison, or when
+/// `is_ref` is set and the operand is not a single-item pointer. Emits a compile error when
+/// the function's return type contains no error set or error union to take `E` from.
+fn zirTryOperandTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index, is_ref: bool) CompileError!Air.Inst.Ref {
+    const pt = sema.pt;
+    const zcu = pt.zcu;
+    const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
+    const src = block.nodeOffset(un_node.src_node);
+
+    // Generic poison is passed through as a poison type rather than reported as an error.
+    const operand_ty = sema.resolveType(block, src, un_node.operand) catch |err| switch (err) {
+        error.GenericPoison => return .generic_poison_type,
+        else => |e| return e,
+    };
+
+    // In the ref case the operand is a pointer type, and the payload is what it points to.
+    const payload_ty = if (is_ref) ty: {
+        if (!operand_ty.isSinglePointer(zcu)) {
+            return .generic_poison_type; // we can't get a meaningful result type here, since it will be `*E![n]T`, and we don't know `n`.
+        }
+        break :ty operand_ty.childType(zcu);
+    } else operand_ty;
+
+    const err_set_ty = err_set: {
+        // There are awkward cases, like `?E`. Our strategy is to repeatedly unwrap optionals
+        // until we hit an error union or set.
+        var cur_ty = sema.fn_ret_ty;
+        while (true) {
+            switch (cur_ty.zigTypeTag(zcu)) {
+                .error_set => break :err_set cur_ty,
+                .error_union => break :err_set cur_ty.errorUnionSet(zcu),
+                .optional => cur_ty = cur_ty.optionalChild(zcu),
+                // Any other type means the function has no error set to propagate through `try`.
+                else => return sema.failWithOwnedErrorMsg(block, msg: {
+                    const msg = try sema.errMsg(src, "expected '{}', found error set", .{sema.fn_ret_ty.fmt(pt)});
+                    errdefer msg.destroy(sema.gpa);
+                    const ret_ty_src: LazySrcLoc = .{
+                        .base_node_inst = sema.getOwnerFuncDeclInst(),
+                        .offset = .{ .node_offset_fn_type_ret_ty = 0 },
+                    };
+                    try sema.errNote(ret_ty_src, msg, "function cannot return an error", .{});
+                    break :msg msg;
+                }),
+            }
+        }
+    };
+
+    const eu_ty = try pt.errorUnionType(err_set_ty, payload_ty);
+
+    if (is_ref) {
+        // Reuse the operand's pointer info, replacing only the pointee with the error union type.
+        var ptr_info = operand_ty.ptrInfo(zcu);
+        ptr_info.child = eu_ty.toIntern();
+        const eu_ptr_ty = try pt.ptrTypeSema(ptr_info);
+        return Air.internedToRef(eu_ptr_ty.toIntern());
+    } else {
+        return Air.internedToRef(eu_ty.toIntern());
+    }
+}
+
 fn zirValidateRefTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void {
     const pt = sema.pt;
     const zcu = pt.zcu;
test/behavior/try.zig
@@ -67,3 +67,22 @@ test "`try`ing an if/else expression" {
 
     try std.testing.expectError(error.Test, S.getError2());
 }
+
+// Checks that the operand of `try` is given a result type: `@errorCast` is written without an
+// explicit destination type, so that type has to be supplied by the enclosing `try` expression.
+test "try forwards result location" {
+    // The test is skipped on the backends listed below.
+    if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    const S = struct {
+        // Returns 123, or `error.Foo` when `err` is true.
+        fn foo(err: bool) error{Foo}!u32 {
+            // The error set of `result` (`error{ Foo, Bar }`) is wider than the `error{Foo}`
+            // declared by `foo`; `@errorCast` is applied to it as the operand of `try`.
+            const result: error{ Foo, Bar }!u32 = if (err) error.Foo else 123;
+            const res_int: u32 = try @errorCast(result);
+            return res_int;
+        }
+    };
+
+    // Cover both the success path and the error path.
+    try expect((S.foo(false) catch return error.TestUnexpectedResult) == 123);
+    try std.testing.expectError(error.Foo, S.foo(true));
+}