Commit 869880adac

Dominic <4678790+dweiller@users.noreply.github.com>
2024-05-11 11:06:13
astgen: fix result info for catch switch_block_err_union
1 parent 511aa28
Changed files (3)
lib/std/zig/AstGen.zig
@@ -7071,8 +7071,10 @@ fn switchExprErrUnion(
         .ctx = ri.ctx,
     };
 
-    const payload_is_ref = node_ty == .@"if" and
-        if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk;
+    const payload_is_ref = switch (node_ty) {
+        .@"if" => if_full.payload_token != null and token_tags[if_full.payload_token.?] == .asterisk,
+        .@"catch" => ri.rl == .ref or ri.rl == .ref_coerced_ty,
+    };
 
     // We need to call `rvalue` to write through to the pointer only if we had a
     // result pointer and aren't forwarding it.
test/behavior/switch_on_captured_error.zig
@@ -3,9 +3,11 @@ const assert = std.debug.assert;
 const expect = std.testing.expect;
 const expectError = std.testing.expectError;
 const expectEqual = std.testing.expectEqual;
+const builtin = @import("builtin");
 
 test "switch on error union catch capture" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     const S = struct {
         const Error = error{ A, B, C };
@@ -16,6 +18,7 @@ test "switch on error union catch capture" {
             try testCapture();
             try testInline();
             try testEmptyErrSet();
+            try testAddressOf();
         }
 
         fn testScalar() !void {
@@ -252,6 +255,44 @@ test "switch on error union catch capture" {
                 try expectEqual(@as(u64, 0), b);
             }
         }
+
+        fn testAddressOf() !void {
+            {
+                const a: anyerror!usize = 0;
+                const ptr = &(a catch |e| switch (e) {
+                    else => 3,
+                });
+                comptime assert(@TypeOf(ptr) == *const usize);
+                try expectEqual(ptr, &(a catch unreachable));
+            }
+            {
+                const a: anyerror!usize = error.A;
+                const ptr = &(a catch |e| switch (e) {
+                    else => 3,
+                });
+                comptime assert(@TypeOf(ptr) == *const comptime_int);
+                try expectEqual(3, ptr.*);
+            }
+            {
+                var a: anyerror!usize = 0;
+                _ = &a;
+                const ptr = &(a catch |e| switch (e) {
+                    else => return,
+                });
+                comptime assert(@TypeOf(ptr) == *usize);
+                ptr.* += 1;
+                try expectEqual(@as(usize, 1), a catch unreachable);
+            }
+            {
+                var a: anyerror!usize = error.A;
+                _ = &a;
+                const ptr = &(a catch |e| switch (e) {
+                    else => return,
+                });
+                comptime assert(@TypeOf(ptr) == *usize);
+                unreachable;
+            }
+        }
     };
 
     try comptime S.doTheTest();
@@ -260,6 +301,7 @@ test "switch on error union catch capture" {
 
 test "switch on error union if else capture" {
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
 
     const S = struct {
         const Error = error{ A, B, C };
@@ -276,6 +318,7 @@ test "switch on error union if else capture" {
             try testInlinePtr();
             try testEmptyErrSet();
             try testEmptyErrSetPtr();
+            try testAddressOf();
         }
 
         fn testScalar() !void {
@@ -747,6 +790,45 @@ test "switch on error union if else capture" {
                 try expectEqual(@as(u64, 0), b);
             }
         }
+
+        fn testAddressOf() !void {
+            if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+            {
+                const a: anyerror!usize = 0;
+                const ptr = &(if (a) |*v| v.* else |e| switch (e) {
+                    else => 3,
+                });
+                comptime assert(@TypeOf(ptr) == *const usize);
+                try expectEqual(ptr, &(a catch unreachable));
+            }
+            {
+                const a: anyerror!usize = error.A;
+                const ptr = &(if (a) |*v| v.* else |e| switch (e) {
+                    else => 3,
+                });
+                comptime assert(@TypeOf(ptr) == *const comptime_int);
+                try expectEqual(3, ptr.*);
+            }
+            {
+                var a: anyerror!usize = 0;
+                _ = &a;
+                const ptr = &(if (a) |*v| v.* else |e| switch (e) {
+                    else => return,
+                });
+                comptime assert(@TypeOf(ptr) == *usize);
+                ptr.* += 1;
+                try expectEqual(@as(usize, 1), a catch unreachable);
+            }
+            {
+                var a: anyerror!usize = error.A;
+                _ = &a;
+                const ptr = &(if (a) |*v| v.* else |e| switch (e) {
+                    else => return,
+                });
+                comptime assert(@TypeOf(ptr) == *usize);
+                unreachable;
+            }
+        }
     };
 
     try comptime S.doTheTest();
test/behavior.zig
@@ -89,6 +89,7 @@ test {
     _ = @import("behavior/switch.zig");
     _ = @import("behavior/switch_prong_err_enum.zig");
     _ = @import("behavior/switch_prong_implicit_cast.zig");
+    _ = @import("behavior/switch_on_captured_error.zig");
     _ = @import("behavior/this.zig");
     _ = @import("behavior/threadlocal.zig");
     _ = @import("behavior/truncate.zig");