Commit 63bd2bff12

Veikka Tuominen <git@vexu.eu>
2023-10-01 12:16:02
Sema: add `@errorCast` which works for both error sets and error unions
Closes #17343
1 parent d8bfbbb
doc/langref.html.in
@@ -6657,7 +6657,7 @@ test "coercion from homogenous tuple to array" {
           <li>{#link|@alignCast#} - make a pointer have more alignment</li>
           <li>{#link|@enumFromInt#} - obtain an enum value based on its integer tag value</li>
           <li>{#link|@errorFromInt#} - obtain an error code based on its integer value</li>
-          <li>{#link|@errSetCast#} - convert to a smaller error set</li>
+          <li>{#link|@errorCast#} - convert to a smaller error set</li>
           <li>{#link|@floatCast#} - convert a larger float to a smaller float</li>
           <li>{#link|@floatFromInt#} - convert an integer to a float value</li>
           <li>{#link|@intCast#} - convert between integer types</li>
@@ -8410,10 +8410,10 @@ test "main" {
       </p>
       {#header_close#}
 
-      {#header_open|@errSetCast#}
-      <pre>{#syntax#}@errSetCast(value: anytype) anytype{#endsyntax#}</pre>
+      {#header_open|@errorCast#}
+      <pre>{#syntax#}@errorCast(value: anytype) anytype{#endsyntax#}</pre>
       <p>
-      Converts an error value from one error set to another error set. The return type is the
+      Converts an error set or error union value from one error set to another error set. The return type is the
 			inferred result type. Attempting to convert an error which is not in the destination error
 			set results in safety-protected {#link|Undefined Behavior#}.
       </p>
@@ -10257,7 +10257,7 @@ const Set2 = error{
     C,
 };
 comptime {
-    _ = @as(Set2, @errSetCast(Set1.B));
+    _ = @as(Set2, @errorCast(Set1.B));
 }
       {#code_end#}
       <p>At runtime:</p>
@@ -10276,7 +10276,7 @@ pub fn main() void {
     foo(Set1.B);
 }
 fn foo(set1: Set1) void {
-    const x: Set2 = @errSetCast(set1);
+    const x: Set2 = @errorCast(set1);
     std.debug.print("value: {}\n", .{x});
 }
       {#code_end#}
lib/std/zig/render.zig
@@ -1444,7 +1444,7 @@ fn renderBuiltinCall(
     const slice = tree.tokenSlice(builtin_token);
     const rewrite_two_param_cast = params.len == 2 and for ([_][]const u8{
         "@bitCast",
-        "@errSetCast",
+        "@errorCast",
         "@floatCast",
         "@intCast",
         "@ptrCast",
@@ -1505,6 +1505,8 @@ fn renderBuiltinCall(
         try ais.writer().writeAll("@intFromPtr");
     } else if (mem.eql(u8, slice, "@fabs")) {
         try ais.writer().writeAll("@abs");
+    } else if (mem.eql(u8, slice, "@errSetCast")) {
+        try ais.writer().writeAll("@errorCast");
     } else {
         try renderToken(ais, tree, builtin_token, .none); // @name
     }
lib/std/child_process.zig
@@ -446,7 +446,7 @@ pub const ChildProcess = struct {
                 // has a value greater than 0
                 if ((fd[0].revents & std.os.POLL.IN) != 0) {
                     const err_int = try readIntFd(err_pipe[0]);
-                    return @as(SpawnError, @errSetCast(@errorFromInt(err_int)));
+                    return @as(SpawnError, @errorCast(@errorFromInt(err_int)));
                 }
             } else {
                 // Write maxInt(ErrInt) to the write end of the err_pipe. This is after
@@ -459,7 +459,7 @@ pub const ChildProcess = struct {
                 // Here we potentially return the fork child's error from the parent
                 // pid.
                 if (err_int != maxInt(ErrInt)) {
-                    return @as(SpawnError, @errSetCast(@errorFromInt(err_int)));
+                    return @as(SpawnError, @errorCast(@errorFromInt(err_int)));
                 }
             }
         }
lib/std/os.zig
@@ -5419,7 +5419,7 @@ pub fn dl_iterate_phdr(
             }
         }.callbackC, @as(?*anyopaque, @ptrFromInt(@intFromPtr(&context))))) {
             0 => return,
-            else => |err| return @as(Error, @errSetCast(@errorFromInt(@as(u16, @intCast(err))))), // TODO don't hardcode u16
+            else => |err| return @as(Error, @errorCast(@errorFromInt(@as(u16, @intCast(err))))), // TODO don't hardcode u16
         }
     }
 
src/AstGen.zig
@@ -8454,11 +8454,11 @@ fn builtinCall(
             });
             return rvalue(gz, ri, result, node);
         },
-        .err_set_cast => {
+        .error_cast => {
             try emitDbgNode(gz, node);
 
-            const result = try gz.addExtendedPayload(.err_set_cast, Zir.Inst.BinNode{
-                .lhs = try ri.rl.resultTypeForCast(gz, node, "@errSetCast"),
+            const result = try gz.addExtendedPayload(.error_cast, Zir.Inst.BinNode{
+                .lhs = try ri.rl.resultTypeForCast(gz, node, "@errorCast"),
                 .rhs = try expr(gz, scope, .{ .rl = .none }, params[0]),
                 .node = gz.nodeIndexToRelative(node),
             });
src/AstRlAnnotate.zig
@@ -945,7 +945,7 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
         .float_cast,
         .int_cast,
         .truncate,
-        .err_set_cast,
+        .error_cast,
         .ptr_cast,
         .align_cast,
         .addrspace_cast,
src/BuiltinFn.zig
@@ -43,7 +43,7 @@ pub const Tag = enum {
     error_name,
     error_return_trace,
     int_from_error,
-    err_set_cast,
+    error_cast,
     @"export",
     @"extern",
     fence,
@@ -455,9 +455,9 @@ pub const list = list: {
             },
         },
         .{
-            "@errSetCast",
+            "@errorCast",
             .{
-                .tag = .err_set_cast,
+                .tag = .error_cast,
                 .eval_to_error = .always,
                 .param_count = 1,
             },
src/print_zir.zig
@@ -594,7 +594,7 @@ const Writer = struct {
 
             .builtin_extern,
             .c_define,
-            .err_set_cast,
+            .error_cast,
             .wasm_memory_grow,
             .prefetch,
             .c_va_arg,
src/Sema.zig
@@ -1252,7 +1252,7 @@ fn analyzeBodyInner(
                     .wasm_memory_size   => try sema.zirWasmMemorySize(    block, extended),
                     .wasm_memory_grow   => try sema.zirWasmMemoryGrow(    block, extended),
                     .prefetch           => try sema.zirPrefetch(          block, extended),
-                    .err_set_cast       => try sema.zirErrSetCast(        block, extended),
+                    .error_cast         => try sema.zirErrorCast(         block, extended),
                     .await_nosuspend    => try sema.zirAwaitNosuspend(    block, extended),
                     .select             => try sema.zirSelect(            block, extended),
                     .int_from_error     => try sema.zirIntFromError(      block, extended),
@@ -21747,17 +21747,31 @@ fn ptrFromIntVal(
     };
 }
 
-fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
+fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
     const mod = sema.mod;
     const ip = &mod.intern_pool;
     const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
     const src = LazySrcLoc.nodeOffset(extra.node);
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
-    const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@errSetCast");
+    const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errorCast");
     const operand = try sema.resolveInst(extra.rhs);
-    const operand_ty = sema.typeOf(operand);
-    try sema.checkErrorSetType(block, src, dest_ty);
-    try sema.checkErrorSetType(block, operand_src, operand_ty);
+    const base_operand_ty = sema.typeOf(operand);
+    const dest_tag = base_dest_ty.zigTypeTag(mod);
+    const operand_tag = base_operand_ty.zigTypeTag(mod);
+    if (dest_tag != operand_tag) {
+        return sema.fail(block, src, "expected source and destination types to match, found '{s}' and '{s}'", .{
+            @tagName(operand_tag), @tagName(dest_tag),
+        });
+    } else if (dest_tag != .ErrorSet and dest_tag != .ErrorUnion) {
+        return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)});
+    }
+    const dest_ty, const operand_ty = if (dest_tag == .ErrorUnion) .{
+        base_dest_ty.errorUnionSet(mod),
+        base_operand_ty.errorUnionSet(mod),
+    } else .{
+        base_dest_ty,
+        base_operand_ty,
+    };
 
     // operand must be defined since it can be an invalid error value
     const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
@@ -21804,8 +21818,15 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
     }
 
     if (maybe_operand_val) |val| {
-        if (!dest_ty.isAnyError(mod)) {
-            const error_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
+        if (!dest_ty.isAnyError(mod)) check: {
+            const operand_val = mod.intern_pool.indexToKey(val.toIntern());
+            var error_name: InternPool.NullTerminatedString = undefined;
+            if (dest_tag == .ErrorUnion) {
+                if (operand_val.error_union.val != .err_name) break :check;
+                error_name = operand_val.error_union.val.err_name;
+            } else {
+                error_name = operand_val.err.name;
+            }
             if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) {
                 const msg = msg: {
                     const msg = try sema.errMsg(
@@ -21822,16 +21843,29 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
             }
         }
 
-        return Air.internedToRef((try mod.getCoerced(val, dest_ty)).toIntern());
+        return Air.internedToRef((try mod.getCoerced(val, base_dest_ty)).toIntern());
     }
 
     try sema.requireRuntimeBlock(block, src, operand_src);
     if (block.wantSafety() and !dest_ty.isAnyError(mod) and sema.mod.backendSupportsFeature(.error_set_has_value)) {
-        const err_int_inst = try block.addBitCast(Type.err_int, operand);
-        const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
-        try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
+        if (dest_tag == .ErrorUnion) {
+            const err_code = try sema.analyzeErrUnionCode(block, operand_src, operand);
+            const err_int = try block.addBitCast(Type.err_int, err_code);
+            const zero_u16 = Air.internedToRef(try mod.intern(.{
+                .int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } },
+            }));
+
+            const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
+            const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16);
+            const ok = try block.addBinOp(.bit_or, has_value, is_zero);
+            try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
+        } else {
+            const err_int_inst = try block.addBitCast(Type.err_int, operand);
+            const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
+            try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
+        }
     }
-    return block.addBitCast(dest_ty, operand);
+    return block.addBitCast(base_dest_ty, operand);
 }
 
 fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
@@ -22916,14 +22950,6 @@ fn checkIntOrVectorAllowComptime(
     }
 }
 
-fn checkErrorSetType(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) CompileError!void {
-    const mod = sema.mod;
-    switch (ty.zigTypeTag(mod)) {
-        .ErrorSet => return,
-        else => return sema.fail(block, src, "expected error set type, found '{}'", .{ty.fmt(mod)}),
-    }
-}
-
 const SimdBinOp = struct {
     len: ?usize,
     /// Coerced to `result_ty`.
src/Zir.zig
@@ -1997,9 +1997,9 @@ pub const Inst = struct {
         /// Implements `@setCold`.
         /// `operand` is payload index to `UnNode`.
         set_cold,
-        /// Implements the `@errSetCast` builtin.
+        /// Implements the `@errorCast` builtin.
         /// `operand` is payload index to `BinNode`. `lhs` is dest type, `rhs` is operand.
-        err_set_cast,
+        error_cast,
         /// `operand` is payload index to `UnNode`.
         await_nosuspend,
         /// Implements `@breakpoint`.
test/behavior/error.zig
@@ -228,13 +228,29 @@ const Set1 = error{ A, B };
 const Set2 = error{ A, C };
 
 fn testExplicitErrorSetCast(set1: Set1) !void {
-    var x = @as(Set2, @errSetCast(set1));
+    var x = @as(Set2, @errorCast(set1));
     try expect(@TypeOf(x) == Set2);
-    var y = @as(Set1, @errSetCast(x));
+    var y = @as(Set1, @errorCast(x));
     try expect(@TypeOf(y) == Set1);
     try expect(y == error.A);
 }
 
+test "@errorCast on error unions" {
+    const S = struct {
+        fn doTheTest() !void {
+            const casted: error{Bad}!i32 = @errorCast(retErrUnion());
+            try expect((try casted) == 1234);
+        }
+
+        fn retErrUnion() anyerror!i32 {
+            return 1234;
+        }
+    };
+
+    try S.doTheTest();
+    try comptime S.doTheTest();
+}
+
 test "comptime test error for empty error set" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
test/cases/compile_errors/explicit_error_set_cast_known_at_comptime_violates_error_sets.zig
@@ -2,7 +2,7 @@ const Set1 = error{ A, B };
 const Set2 = error{ A, C };
 comptime {
     var x = Set1.B;
-    var y: Set2 = @errSetCast(x);
+    var y: Set2 = @errorCast(x);
     _ = y;
 }
 
test/cases/compile_errors/int_to_err_non_global_invalid_number.zig
@@ -8,7 +8,7 @@ const Set2 = error{
 };
 comptime {
     var x = @intFromError(Set1.B);
-    var y: Set2 = @errSetCast(@errorFromInt(x));
+    var y: Set2 = @errorCast(@errorFromInt(x));
     _ = y;
 }
 
test/cases/safety/@errSetCast error not present in destination.zig → test/cases/safety/@errorCast error not present in destination.zig
@@ -14,7 +14,7 @@ pub fn main() !void {
     return error.TestFailed;
 }
 fn foo(set1: Set1) Set2 {
-    return @errSetCast(set1);
+    return @errorCast(set1);
 }
 // run
 // backend=llvm