Commit 0bdbd3e235

Veikka Tuominen <git@vexu.eu>
2023-10-02 14:44:50
Sema: fix issues in `@errorCast` with error unions
1 parent c9c3ee7
Changed files (3)
src/Sema.zig
@@ -21771,10 +21771,10 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
     // operand must be defined since it can be an invalid error value
     const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
 
-    if (disjoint: {
+    const disjoint = disjoint: {
         // Try avoiding resolving inferred error sets if we can
-        if (!dest_ty.isAnyError(mod) and dest_ty.errorSetNames(mod).len == 0) break :disjoint true;
-        if (!operand_ty.isAnyError(mod) and operand_ty.errorSetNames(mod).len == 0) break :disjoint true;
+        if (!dest_ty.isAnyError(mod) and dest_ty.errorSetIsEmpty(mod)) break :disjoint true;
+        if (!operand_ty.isAnyError(mod) and operand_ty.errorSetIsEmpty(mod)) break :disjoint true;
         if (dest_ty.isAnyError(mod)) break :disjoint false;
         if (operand_ty.isAnyError(mod)) break :disjoint false;
         for (dest_ty.errorSetNames(mod)) |dest_err_name| {
@@ -21796,7 +21796,8 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
         }
 
         break :disjoint true;
-    }) {
+    };
+    if (disjoint and dest_tag != .ErrorUnion) {
         const msg = msg: {
             const msg = try sema.errMsg(
                 block,
@@ -21850,10 +21851,16 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData
                 .int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } },
             }));
 
-            const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
             const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16);
-            const ok = try block.addBinOp(.bit_or, has_value, is_zero);
-            try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
+            if (disjoint) {
+                // Error must be zero.
+                try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code);
+            } else {
+                // Error must be in destination set or zero.
+                const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
+                const ok = try block.addBinOp(.bit_or, has_value, is_zero);
+                try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
+            }
         } else {
             const err_int_inst = try block.addBitCast(Type.err_int, operand);
             const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
test/behavior/error.zig
@@ -238,13 +238,23 @@ fn testExplicitErrorSetCast(set1: Set1) !void {
 test "@errorCast on error unions" {
     const S = struct {
         fn doTheTest() !void {
-            const casted: error{Bad}!i32 = @errorCast(retErrUnion());
-            try expect((try casted) == 1234);
+            {
+                const casted: error{Bad}!i32 = @errorCast(retErrUnion());
+                try expect((try casted) == 1234);
+            }
+            {
+                const casted: error{Bad}!i32 = @errorCast(retInferredErrUnion());
+                try expect((try casted) == 5678);
+            }
         }
 
         fn retErrUnion() anyerror!i32 {
             return 1234;
         }
+
+        fn retInferredErrUnion() !i32 {
+            return 5678;
+        }
     };
 
     try S.doTheTest();
test/cases/safety/@errorCast error union casted to disjoint set.zig
@@ -0,0 +1,20 @@
+const std = @import("std");
+
+pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn {
+    _ = stack_trace;
+    if (std.mem.eql(u8, message, "invalid error code")) {
+        std.process.exit(0);
+    }
+    std.process.exit(1);
+}
+pub fn main() !void {
+    const bar: error{Foo}!i32 = @errorCast(foo());
+    _ = &bar;
+    return error.TestFailed;
+}
+fn foo() anyerror!i32 {
+    return error.Bar;
+}
+// run
+// backend=llvm
+// target=native