Commit 5317d88414

mlugg <mlugg@mlugg.co.uk>
2025-02-05 22:06:39
Sema: fix `@errorCast` with error unions
Resolves: #20169
1 parent fbbf34e
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -23175,11 +23175,12 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
     const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
     const src = block.nodeOffset(extra.node);
     const operand_src = block.builtinCallArgSrc(extra.node, 0);
-    const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
+    const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
     const operand = try sema.resolveInst(extra.rhs);
-    const base_operand_ty = sema.typeOf(operand);
-    const dest_tag = base_dest_ty.zigTypeTag(zcu);
-    const operand_tag = base_operand_ty.zigTypeTag(zcu);
+    const operand_ty = sema.typeOf(operand);
+
+    const dest_tag = dest_ty.zigTypeTag(zcu);
+    const operand_tag = operand_ty.zigTypeTag(zcu);
 
     if (dest_tag != .error_set and dest_tag != .error_union) {
         return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)});
@@ -23191,107 +23192,133 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
         return sema.fail(block, src, "cannot cast an error union type to error set", .{});
     }
     if (dest_tag == .error_union and operand_tag == .error_union and
-        base_dest_ty.errorUnionPayload(zcu).toIntern() != base_operand_ty.errorUnionPayload(zcu).toIntern())
+        dest_ty.errorUnionPayload(zcu).toIntern() != operand_ty.errorUnionPayload(zcu).toIntern())
     {
         return sema.failWithOwnedErrorMsg(block, msg: {
             const msg = try sema.errMsg(src, "payload types of error unions must match", .{});
             errdefer msg.destroy(sema.gpa);
-            const dest_ty = base_dest_ty.errorUnionPayload(zcu);
-            const operand_ty = base_operand_ty.errorUnionPayload(zcu);
-            try sema.errNote(src, msg, "destination payload is '{}'", .{dest_ty.fmt(pt)});
-            try sema.errNote(src, msg, "operand payload is '{}'", .{operand_ty.fmt(pt)});
+            const dest_payload_ty = dest_ty.errorUnionPayload(zcu);
+            const operand_payload_ty = operand_ty.errorUnionPayload(zcu);
+            try sema.errNote(src, msg, "destination payload is '{}'", .{dest_payload_ty.fmt(pt)});
+            try sema.errNote(src, msg, "operand payload is '{}'", .{operand_payload_ty.fmt(pt)});
             try addDeclaredHereNote(sema, msg, dest_ty);
             try addDeclaredHereNote(sema, msg, operand_ty);
             break :msg msg;
         });
     }
-    const dest_ty = if (dest_tag == .error_union) base_dest_ty.errorUnionSet(zcu) else base_dest_ty;
-    const operand_ty = if (operand_tag == .error_union) base_operand_ty.errorUnionSet(zcu) else base_operand_ty;
-
-    // operand must be defined since it can be an invalid error value
-    const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
+    const dest_err_ty = switch (dest_tag) {
+        .error_union => dest_ty.errorUnionSet(zcu),
+        .error_set => dest_ty,
+        else => unreachable,
+    };
+    const operand_err_ty = switch (operand_tag) {
+        .error_union => operand_ty.errorUnionSet(zcu),
+        .error_set => operand_ty,
+        else => unreachable,
+    };
 
     const disjoint = disjoint: {
         // Try avoiding resolving inferred error sets if we can
-        if (!dest_ty.isAnyError(zcu) and dest_ty.errorSetIsEmpty(zcu)) break :disjoint true;
-        if (!operand_ty.isAnyError(zcu) and operand_ty.errorSetIsEmpty(zcu)) break :disjoint true;
-        if (dest_ty.isAnyError(zcu)) break :disjoint false;
-        if (operand_ty.isAnyError(zcu)) break :disjoint false;
-        const dest_err_names = dest_ty.errorSetNames(zcu);
+        if (!dest_err_ty.isAnyError(zcu) and dest_err_ty.errorSetIsEmpty(zcu)) break :disjoint true;
+        if (!operand_err_ty.isAnyError(zcu) and operand_err_ty.errorSetIsEmpty(zcu)) break :disjoint true;
+        if (dest_err_ty.isAnyError(zcu)) break :disjoint false;
+        if (operand_err_ty.isAnyError(zcu)) break :disjoint false;
+        const dest_err_names = dest_err_ty.errorSetNames(zcu);
         for (0..dest_err_names.len) |dest_err_index| {
-            if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
+            if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
                 break :disjoint false;
         }
 
-        if (!ip.isInferredErrorSetType(dest_ty.toIntern()) and
-            !ip.isInferredErrorSetType(operand_ty.toIntern()))
+        if (!ip.isInferredErrorSetType(dest_err_ty.toIntern()) and
+            !ip.isInferredErrorSetType(operand_err_ty.toIntern()))
         {
             break :disjoint true;
         }
 
-        _ = try sema.resolveInferredErrorSetTy(block, src, dest_ty.toIntern());
-        _ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_ty.toIntern());
+        _ = try sema.resolveInferredErrorSetTy(block, src, dest_err_ty.toIntern());
+        _ = try sema.resolveInferredErrorSetTy(block, operand_src, operand_err_ty.toIntern());
         for (0..dest_err_names.len) |dest_err_index| {
-            if (Type.errorSetHasFieldIp(ip, operand_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
+            if (Type.errorSetHasFieldIp(ip, operand_err_ty.toIntern(), dest_err_names.get(ip)[dest_err_index]))
                 break :disjoint false;
         }
 
         break :disjoint true;
     };
-    if (disjoint and dest_tag != .error_union) {
+    if (disjoint and !(operand_tag == .error_union and dest_tag == .error_union)) {
         return sema.fail(block, src, "error sets '{}' and '{}' have no common errors", .{
-            operand_ty.fmt(pt), dest_ty.fmt(pt),
+            operand_err_ty.fmt(pt), dest_err_ty.fmt(pt),
         });
     }
 
-    if (maybe_operand_val) |val| {
-        if (!dest_ty.isAnyError(zcu)) check: {
-            const operand_val = zcu.intern_pool.indexToKey(val.toIntern());
-            var error_name: InternPool.NullTerminatedString = undefined;
-            if (operand_tag == .error_union) {
-                if (operand_val.error_union.val != .err_name) break :check;
-                error_name = operand_val.error_union.val.err_name;
-            } else {
-                error_name = operand_val.err.name;
-            }
-            if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) {
-                return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{
-                    error_name.fmt(ip), dest_ty.fmt(pt),
-                });
-            }
+    // operand must be defined since it can be an invalid error value
+    if (try sema.resolveDefinedValue(block, operand_src, operand)) |operand_val| {
+        const err_name: InternPool.NullTerminatedString = switch (operand_tag) {
+            .error_set => ip.indexToKey(operand_val.toIntern()).err.name,
+            .error_union => switch (ip.indexToKey(operand_val.toIntern()).error_union.val) {
+                .err_name => |name| name,
+                .payload => |payload_val| {
+                    assert(dest_tag == .error_union); // should be guaranteed from the type checks above
+                    return sema.coerce(block, dest_ty, Air.internedToRef(payload_val), operand_src);
+                },
+            },
+            else => unreachable,
+        };
+
+        if (!dest_err_ty.isAnyError(zcu) and !Type.errorSetHasFieldIp(ip, dest_err_ty.toIntern(), err_name)) {
+            return sema.fail(block, src, "'error.{}' not a member of error set '{}'", .{
+                err_name.fmt(ip), dest_err_ty.fmt(pt),
+            });
         }
 
-        return Air.internedToRef((try pt.getCoerced(val, base_dest_ty)).toIntern());
+        return Air.internedToRef(try pt.intern(switch (dest_tag) {
+            .error_set => .{ .err = .{
+                .ty = dest_ty.toIntern(),
+                .name = err_name,
+            } },
+            .error_union => .{ .error_union = .{
+                .ty = dest_ty.toIntern(),
+                .val = .{ .err_name = err_name },
+            } },
+            else => unreachable,
+        }));
     }
 
-    try sema.requireRuntimeBlock(block, src, operand_src);
     const err_int_ty = try pt.errorIntType();
-    if (block.wantSafety() and !dest_ty.isAnyError(zcu) and
-        dest_ty.toIntern() != .adhoc_inferred_error_set_type and
+    if (block.wantSafety() and !dest_err_ty.isAnyError(zcu) and
+        dest_err_ty.toIntern() != .adhoc_inferred_error_set_type and
         zcu.backendSupportsFeature(.error_set_has_value))
     {
-        if (dest_tag == .error_union) {
-            const err_code = try block.addTyOp(.unwrap_errunion_err, operand_ty, operand);
-            const err_int = try block.addBitCast(err_int_ty, err_code);
-            const zero_err = try pt.intRef(try pt.errorIntType(), 0);
+        const err_code_inst = switch (operand_tag) {
+            .error_set => operand,
+            .error_union => try block.addTyOp(.unwrap_errunion_err, operand_err_ty, operand),
+            else => unreachable,
+        };
+        const err_int_inst = try block.addBitCast(err_int_ty, err_code_inst);
 
-            const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_err);
+        if (dest_tag == .error_union) {
+            const zero_err = try pt.intRef(err_int_ty, 0);
+            const is_zero = try block.addBinOp(.cmp_eq, err_int_inst, zero_err);
             if (disjoint) {
                 // Error must be zero.
                 try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code);
             } else {
                 // Error must be in destination set or zero.
-                const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
+                const has_value = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst);
                 const ok = try block.addBinOp(.bool_or, has_value, is_zero);
                 try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
             }
         } else {
-            const err_int_inst = try block.addBitCast(err_int_ty, operand);
-            const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
+            const ok = try block.addTyOp(.error_set_has_value, dest_err_ty, err_int_inst);
             try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
         }
     }
-    return block.addBitCast(base_dest_ty, operand);
+
+    if (operand_tag == .error_set and dest_tag == .error_union) {
+        const err_val = try block.addBitCast(dest_err_ty, operand);
+        return block.addTyOp(.wrap_errunion_err, dest_ty, err_val);
+    } else {
+        return block.addBitCast(dest_ty, operand);
+    }
 }
 
 fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
test/behavior/error.zig
@@ -1060,9 +1060,24 @@ test "errorCast to adhoc inferred error set" {
     try std.testing.expect((try S.baz()) == 1234);
 }
 
-test "errorCast from error sets to error unions" {
-    const err_union: Set1!void = @errorCast(error.A);
-    try expectError(error.A, err_union);
+test "@errorCast from error set to error union" {
+    const S = struct {
+        fn doTheTest(set: error{ A, B }) error{A}!i32 {
+            return @errorCast(set);
+        }
+    };
+    try expectError(error.A, S.doTheTest(error.A));
+    try expectError(error.A, comptime S.doTheTest(error.A));
+}
+
+test "@errorCast from error union to error union" {
+    const S = struct {
+        fn doTheTest(set: error{ A, B }!i32) error{A}!i32 {
+            return @errorCast(set);
+        }
+    };
+    try expectError(error.A, S.doTheTest(error.A));
+    try expectError(error.A, comptime S.doTheTest(error.A));
 }
 
 test "result location initialization of error union with OPV payload" {