Commit 1d9b1c0212

Justus Klausecker <justus@klausecker.de>
2025-07-08 01:32:49
Permit explicit tags with '_' switch prong Mainly affects ZIR representation of switch_block[_ref] and special prong (detection) logic for switch. Adds a new SpecialProng tag 'absorbing_under' that allows specifying additional explicit tags in a '_' prong which are respected when checking that every value is handled during semantic analysis but are not transformed into AIR and instead 'absorbed' by the '_' branch.
1 parent fd9cfc3
lib/std/zig/Ast.zig
@@ -2877,6 +2877,24 @@ pub const full = struct {
             arrow_token: TokenIndex,
             target_expr: Node.Index,
         };
+
+        /// Returns:
+        ///   `null` if case is not special
+        ///   `.none` if case is else prong
+        ///   Index of underscore otherwise
+        pub fn isSpecial(case: *const SwitchCase, tree: *const Ast) ?Node.OptionalIndex {
+            if (case.ast.values.len == 0) {
+                return .none;
+            }
+            for (case.ast.values) |val| {
+                if (tree.nodeTag(val) == .identifier and
+                    mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(val)), "_"))
+                {
+                    return val.toOptional();
+                }
+            }
+            return null;
+        }
     };
 
     pub const Asm = struct {
lib/std/zig/AstGen.zig
@@ -7666,6 +7666,7 @@ fn switchExpr(
     var special_node: Ast.Node.OptionalIndex = .none;
     var else_src: ?Ast.TokenIndex = null;
     var underscore_src: ?Ast.TokenIndex = null;
+    var underscore_node: Ast.Node.OptionalIndex = .none;
     for (case_nodes) |case_node| {
         const case = tree.fullSwitchCase(case_node).?;
         if (case.payload_token) |payload_token| {
@@ -7686,7 +7687,7 @@ fn switchExpr(
                 any_non_inline_capture = true;
             }
         }
-        // Check for else/`_` prong.
+        // Check for else prong.
         if (case.ast.values.len == 0) {
             const case_src = case.ast.arrow_token - 1;
             if (else_src) |src| {
@@ -7725,56 +7726,60 @@ fn switchExpr(
             special_prong = .@"else";
             else_src = case_src;
             continue;
-        } else if (case.ast.values.len == 1 and
-            tree.nodeTag(case.ast.values[0]) == .identifier and
-            mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(case.ast.values[0])), "_"))
-        {
-            const case_src = case.ast.arrow_token - 1;
-            if (underscore_src) |src| {
-                return astgen.failTokNotes(
-                    case_src,
-                    "multiple '_' prongs in switch expression",
-                    .{},
-                    &[_]u32{
-                        try astgen.errNoteTok(
-                            src,
-                            "previous '_' prong here",
-                            .{},
-                        ),
-                    },
-                );
-            } else if (else_src) |some_else| {
-                return astgen.failNodeNotes(
-                    node,
-                    "else and '_' prong in switch expression",
-                    .{},
-                    &[_]u32{
-                        try astgen.errNoteTok(
-                            some_else,
-                            "else prong here",
-                            .{},
-                        ),
-                        try astgen.errNoteTok(
-                            case_src,
-                            "'_' prong here",
-                            .{},
-                        ),
-                    },
-                );
-            }
-            if (case.inline_token != null) {
-                return astgen.failTok(case_src, "cannot inline '_' prong", .{});
-            }
-            special_node = case_node.toOptional();
-            special_prong = .under;
-            underscore_src = case_src;
-            continue;
         }
 
+        // Check for '_' prong.
+        var found_underscore = false;
         for (case.ast.values) |val| {
-            if (tree.nodeTag(val) == .string_literal)
-                return astgen.failNode(val, "cannot switch on strings", .{});
+            switch (tree.nodeTag(val)) {
+                .identifier => if (mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(val)), "_")) {
+                    const case_src = case.ast.arrow_token - 1;
+                    if (underscore_src) |src| {
+                        return astgen.failTokNotes(
+                            case_src,
+                            "multiple '_' prongs in switch expression",
+                            .{},
+                            &[_]u32{
+                                try astgen.errNoteTok(
+                                    src,
+                                    "previous '_' prong here",
+                                    .{},
+                                ),
+                            },
+                        );
+                    } else if (else_src) |some_else| {
+                        return astgen.failNodeNotes(
+                            node,
+                            "else and '_' prong in switch expression",
+                            .{},
+                            &[_]u32{
+                                try astgen.errNoteTok(
+                                    some_else,
+                                    "else prong here",
+                                    .{},
+                                ),
+                                try astgen.errNoteTok(
+                                    case_src,
+                                    "'_' prong here",
+                                    .{},
+                                ),
+                            },
+                        );
+                    }
+                    if (case.inline_token != null) {
+                        return astgen.failTok(case_src, "cannot inline '_' prong", .{});
+                    }
+                    special_node = case_node.toOptional();
+                    special_prong = if (case.ast.values.len == 1) .under else .absorbing_under;
+                    underscore_src = case_src;
+                    underscore_node = val.toOptional();
+                    found_underscore = true;
+                },
+                .string_literal => return astgen.failNode(val, "cannot switch on strings", .{}),
+                else => {},
+            }
         }
+        if (found_underscore) continue;
 
         if (case.ast.values.len == 1 and tree.nodeTag(case.ast.values[0]) != .switch_range) {
             scalar_cases_len += 1;
@@ -7938,14 +7943,23 @@ fn switchExpr(
 
         const header_index: u32 = @intCast(payloads.items.len);
         const body_len_index = if (is_multi_case) blk: {
-            payloads.items[multi_case_table + multi_case_index] = header_index;
-            multi_case_index += 1;
+            if (case_node.toOptional() == special_node) {
+                assert(special_prong == .absorbing_under);
+                payloads.items[case_table_start] = header_index;
+            } else {
+                payloads.items[multi_case_table + multi_case_index] = header_index;
+                multi_case_index += 1;
+            }
             try payloads.resize(gpa, header_index + 3); // items_len, ranges_len, body_len
 
             // items
             var items_len: u32 = 0;
             for (case.ast.values) |item_node| {
-                if (tree.nodeTag(item_node) == .switch_range) continue;
+                if (item_node.toOptional() == underscore_node or
+                    tree.nodeTag(item_node) == .switch_range)
+                {
+                    continue;
+                }
                 items_len += 1;
 
                 const item_inst = try comptimeExpr(parent_gz, scope, item_ri, item_node, .switch_item);
@@ -7955,7 +7969,9 @@ fn switchExpr(
             // ranges
             var ranges_len: u32 = 0;
             for (case.ast.values) |range| {
-                if (tree.nodeTag(range) != .switch_range) continue;
+                if (tree.nodeTag(range) != .switch_range) {
+                    continue;
+                }
                 ranges_len += 1;
 
                 const first_node, const last_node = tree.nodeData(range).node_and_node;
@@ -7970,6 +7986,7 @@ fn switchExpr(
             payloads.items[header_index + 1] = ranges_len;
             break :blk header_index + 2;
         } else if (case_node.toOptional() == special_node) blk: {
+            assert(special_prong != .absorbing_under);
             payloads.items[case_table_start] = header_index;
             try payloads.resize(gpa, header_index + 1); // body_len
             break :blk header_index;
@@ -8025,15 +8042,13 @@ fn switchExpr(
     try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.SwitchBlock).@"struct".fields.len +
         @intFromBool(multi_cases_len != 0) +
         @intFromBool(any_has_tag_capture) +
-        payloads.items.len - case_table_end +
-        (case_table_end - case_table_start) * @typeInfo(Zir.Inst.As).@"struct".fields.len);
+        payloads.items.len - scratch_top);
 
     const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.SwitchBlock{
         .operand = raw_operand,
         .bits = Zir.Inst.SwitchBlock.Bits{
             .has_multi_cases = multi_cases_len != 0,
-            .has_else = special_prong == .@"else",
-            .has_under = special_prong == .under,
+            .special_prong = special_prong,
             .any_has_tag_capture = any_has_tag_capture,
             .any_non_inline_capture = any_non_inline_capture,
             .has_continue = switch_full.label_token != null and block_scope.label.?.used_for_continue,
@@ -8052,13 +8067,30 @@ fn switchExpr(
     const zir_datas = astgen.instructions.items(.data);
     zir_datas[@intFromEnum(switch_block)].pl_node.payload_index = payload_index;
 
-    for (payloads.items[case_table_start..case_table_end], 0..) |start_index, i| {
+    var normal_case_table_start = case_table_start;
+    if (special_prong != .none) {
+        normal_case_table_start += 1;
+
+        const start_index = payloads.items[case_table_start];
         var body_len_index = start_index;
         var end_index = start_index;
-        const table_index = case_table_start + i;
-        if (table_index < scalar_case_table) {
+        if (special_prong == .absorbing_under) {
+            body_len_index += 2;
+            const items_len = payloads.items[start_index];
+            const ranges_len = payloads.items[start_index + 1];
+            end_index += 3 + items_len + 2 * ranges_len;
+        } else {
             end_index += 1;
-        } else if (table_index < multi_case_table) {
+        }
+        const prong_info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(payloads.items[body_len_index]);
+        end_index += prong_info.body_len;
+        astgen.extra.appendSliceAssumeCapacity(payloads.items[start_index..end_index]);
+    }
+    for (payloads.items[normal_case_table_start..case_table_end], 0..) |start_index, i| {
+        var body_len_index = start_index;
+        var end_index = start_index;
+        const table_index = normal_case_table_start + i;
+        if (table_index < multi_case_table) {
             body_len_index += 1;
             end_index += 2;
         } else {
lib/std/zig/Zir.zig
@@ -3226,8 +3226,15 @@ pub const Inst = struct {
 
     /// 0. multi_cases_len: u32 // If has_multi_cases is set.
     /// 1. tag_capture_inst: u32 // If any_has_tag_capture is set. Index of instruction prongs use to refer to the inline tag capture.
-    /// 2. else_body { // If has_else or has_under is set.
+    /// 2. else_body { // If special_prong != .none
+    ///        items_len: u32, // If special_prong == .absorbing_under
+    ///        ranges_len: u32, // If special_prong == .absorbing_under
     ///        info: ProngInfo,
+    ///        item: Ref, // for every items_len
+    ///        ranges: { // for every ranges_len
+    ///            item_first: Ref,
+    ///            item_last: Ref,
+    ///        }
     ///        body member Index for every info.body_len
     ///     }
     /// 3. scalar_cases: { // for every scalar_cases_len
@@ -3239,7 +3246,7 @@ pub const Inst = struct {
     ///        items_len: u32,
     ///        ranges_len: u32,
     ///        info: ProngInfo,
-    ///        item: Ref // for every items_len
+    ///        item: Ref, // for every items_len
     ///        ranges: { // for every ranges_len
     ///            item_first: Ref,
     ///            item_last: Ref,
@@ -3275,10 +3282,8 @@ pub const Inst = struct {
         pub const Bits = packed struct(u32) {
             /// If true, one or more prongs have multiple items.
             has_multi_cases: bool,
-            /// If true, there is an else prong. This is mutually exclusive with `has_under`.
-            has_else: bool,
-            /// If true, there is an underscore prong. This is mutually exclusive with `has_else`.
-            has_under: bool,
+            /// Information about the special prong.
+            special_prong: SpecialProng,
             /// If true, at least one prong has an inline tag capture.
             any_has_tag_capture: bool,
             /// If true, at least one prong has a capture which may not
@@ -3288,17 +3293,6 @@ pub const Inst = struct {
             scalar_cases_len: ScalarCasesLen,
 
             pub const ScalarCasesLen = u26;
-
-            pub fn specialProng(bits: Bits) SpecialProng {
-                const has_else: u2 = @intFromBool(bits.has_else);
-                const has_under: u2 = @intFromBool(bits.has_under);
-                return switch ((has_else << 1) | has_under) {
-                    0b00 => .none,
-                    0b01 => .under,
-                    0b10 => .@"else",
-                    0b11 => unreachable,
-                };
-            }
         };
 
         pub const MultiProng = struct {
@@ -3874,7 +3868,18 @@ pub const Inst = struct {
     };
 };
 
-pub const SpecialProng = enum { none, @"else", under };
+pub const SpecialProng = enum(u2) {
+    none,
+    /// Simple else prong.
+    /// `else => {}`
+    @"else",
+    /// Simple '_' prong.
+    /// `_ => {}`
+    under,
+    /// '_' prong with additional items.
+    /// `a, _, b => {}`
+    absorbing_under,
+};
 
 pub const DeclIterator = struct {
     extra_index: u32,
@@ -4718,7 +4723,7 @@ fn findTrackableSwitch(
     }
 
     const has_special = switch (kind) {
-        .normal => extra.data.bits.specialProng() != .none,
+        .normal => extra.data.bits.special_prong != .none,
         .err_union => has_special: {
             // Handle `non_err_body` first.
             const prong_info: Inst.SwitchBlock.ProngInfo = @bitCast(zir.extra[extra_index]);
@@ -4733,6 +4738,23 @@ fn findTrackableSwitch(
     };
 
     if (has_special) {
+        if (kind == .normal) {
+            if (extra.data.bits.special_prong == .absorbing_under) {
+                const items_len = zir.extra[extra_index];
+                extra_index += 1;
+                const ranges_len = zir.extra[extra_index];
+                extra_index += 1;
+                const prong_info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(zir.extra[extra_index]);
+                extra_index += 1;
+
+                extra_index += items_len + ranges_len * 2;
+
+                const body = zir.bodySlice(extra_index, prong_info.body_len);
+                extra_index += body.len;
+
+                try zir.findTrackableBody(gpa, contents, defers, body);
+            }
+        }
         const prong_info: Inst.SwitchBlock.ProngInfo = @bitCast(zir.extra[extra_index]);
         extra_index += 1;
         const body = zir.bodySlice(extra_index, prong_info.body_len);
src/print_zir.zig
@@ -2088,27 +2088,57 @@ const Writer = struct {
         self.indent += 2;
 
         else_prong: {
-            const special_prong = extra.data.bits.specialProng();
-            const prong_name = switch (special_prong) {
-                .@"else" => "else",
-                .under => "_",
-                else => break :else_prong,
-            };
+            const special_prong = extra.data.bits.special_prong;
+            if (special_prong == .none) break :else_prong;
 
+            var items_len: u32 = 0;
+            var ranges_len: u32 = 0;
+            if (special_prong == .absorbing_under) {
+                items_len = self.code.extra[extra_index];
+                extra_index += 1;
+                ranges_len = self.code.extra[extra_index];
+                extra_index += 1;
+            }
             const info = @as(Zir.Inst.SwitchBlock.ProngInfo, @bitCast(self.code.extra[extra_index]));
-            const capture_text = switch (info.capture) {
-                .none => "",
-                .by_val => "by_val ",
-                .by_ref => "by_ref ",
-            };
-            const inline_text = if (info.is_inline) "inline " else "";
             extra_index += 1;
-            const body = self.code.bodySlice(extra_index, info.body_len);
-            extra_index += body.len;
+            const items = self.code.refSlice(extra_index, items_len);
+            extra_index += items_len;
 
             try stream.writeAll(",\n");
             try stream.splatByteAll(' ', self.indent);
-            try stream.print("{s}{s}{s} => ", .{ capture_text, inline_text, prong_name });
+            switch (info.capture) {
+                .none => {},
+                .by_val => try stream.writeAll("by_val "),
+                .by_ref => try stream.writeAll("by_ref "),
+            }
+            if (info.is_inline) try stream.writeAll("inline ");
+            switch (special_prong) {
+                .@"else" => try stream.writeAll("else"),
+                .under, .absorbing_under => try stream.writeAll("_"),
+                .none => unreachable,
+            }
+
+            for (items) |item_ref| {
+                try stream.writeAll(", ");
+                try self.writeInstRef(stream, item_ref);
+            }
+
+            var range_i: usize = 0;
+            while (range_i < ranges_len) : (range_i += 1) {
+                const item_first = @as(Zir.Inst.Ref, @enumFromInt(self.code.extra[extra_index]));
+                extra_index += 1;
+                const item_last = @as(Zir.Inst.Ref, @enumFromInt(self.code.extra[extra_index]));
+                extra_index += 1;
+
+                try stream.writeAll(", ");
+                try self.writeInstRef(stream, item_first);
+                try stream.writeAll("...");
+                try self.writeInstRef(stream, item_last);
+            }
+
+            const body = self.code.bodySlice(extra_index, info.body_len);
+            extra_index += info.body_len;
+            try stream.writeAll(" => ");
             try self.writeBracedBody(stream, body);
         }
 
src/Sema.zig
@@ -11335,7 +11335,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
     var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len);
     defer case_vals.deinit(gpa);
 
-    const special_prong = extra.data.bits.specialProng();
+    var absorbed_items: []const Zir.Inst.Ref = &.{};
+    var absorbed_ranges: []const Zir.Inst.Ref = &.{};
+
+    const special_prong = extra.data.bits.special_prong;
     const special: SpecialProng = switch (special_prong) {
         .none => .{
             .body = &.{},
@@ -11355,6 +11358,26 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
                 .has_tag_capture = info.has_tag_capture,
             };
         },
+        .absorbing_under => blk: {
+            var extra_index = header_extra_index;
+            const items_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const ranges_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
+            extra_index += 1;
+            absorbed_items = sema.code.refSlice(extra_index, items_len);
+            extra_index += items_len;
+            absorbed_ranges = sema.code.refSlice(extra_index, ranges_len * 2);
+            extra_index += ranges_len * 2;
+            break :blk .{
+                .body = sema.code.bodySlice(extra_index, info.body_len),
+                .end = extra_index + info.body_len,
+                .capture = info.capture,
+                .is_inline = info.is_inline,
+                .has_tag_capture = info.has_tag_capture,
+            };
+        },
     };
 
     // Duplicate checking variables later also used for `inline else`.
@@ -11375,7 +11398,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
     var else_error_ty: ?Type = null;
 
     // Validate usage of '_' prongs.
-    if (special_prong == .under and !raw_operand_ty.isNonexhaustiveEnum(zcu)) {
+    if ((special_prong == .under or special_prong == .absorbing_under) and
+        !raw_operand_ty.isNonexhaustiveEnum(zcu))
+    {
         const msg = msg: {
             const msg = try sema.errMsg(
                 src,
@@ -11409,6 +11434,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
             @memset(seen_enum_fields, null);
             // `range_set` is used for non-exhaustive enum values that do not correspond to any tags.
 
+            for (absorbed_items, 0..) |item_ref, item_i| {
+                _ = try sema.validateSwitchItemEnum(
+                    block,
+                    seen_enum_fields,
+                    &range_set,
+                    item_ref,
+                    cond_ty,
+                    block.src(.{ .switch_case_item = .{
+                        .switch_node_offset = src_node_offset,
+                        .case_idx = .special,
+                        .item_idx = .{ .kind = .single, .index = @intCast(item_i) },
+                    } }),
+                );
+            }
+            try sema.validateSwitchNoRange(block, @intCast(absorbed_ranges.len), cond_ty, src_node_offset);
+
             var extra_index: usize = special.end;
             {
                 var scalar_i: u32 = 0;
@@ -11692,7 +11733,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
                         );
                     }
                 },
-                .under, .none => {
+                .under, .absorbing_under, .none => {
                     if (true_count + false_count < 2) {
                         return sema.fail(
                             block,
@@ -11892,7 +11933,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
                     special.capture,
                     block.src(.{ .switch_capture = .{
                         .switch_node_offset = src_node_offset,
-                        .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                        .case_idx = .special,
                     } }),
                     undefined, // case_vals may be undefined for special prongs
                     .none,
@@ -12344,7 +12385,7 @@ fn analyzeSwitchRuntimeBlock(
                             special.capture,
                             child_block.src(.{ .switch_capture = .{
                                 .switch_node_offset = switch_node_offset,
-                                .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                                .case_idx = .special,
                             } }),
                             &.{item_ref},
                             item_ref,
@@ -12399,7 +12440,7 @@ fn analyzeSwitchRuntimeBlock(
                         special.capture,
                         child_block.src(.{ .switch_capture = .{
                             .switch_node_offset = switch_node_offset,
-                            .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                            .case_idx = .special,
                         } }),
                         &.{item_ref},
                         item_ref,
@@ -12439,7 +12480,7 @@ fn analyzeSwitchRuntimeBlock(
                         special.capture,
                         child_block.src(.{ .switch_capture = .{
                             .switch_node_offset = switch_node_offset,
-                            .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                            .case_idx = .special,
                         } }),
                         &.{item_ref},
                         item_ref,
@@ -12476,7 +12517,7 @@ fn analyzeSwitchRuntimeBlock(
                         special.capture,
                         child_block.src(.{ .switch_capture = .{
                             .switch_node_offset = switch_node_offset,
-                            .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                            .case_idx = .special,
                         } }),
                         &.{.bool_true},
                         .bool_true,
@@ -12511,7 +12552,7 @@ fn analyzeSwitchRuntimeBlock(
                         special.capture,
                         child_block.src(.{ .switch_capture = .{
                             .switch_node_offset = switch_node_offset,
-                            .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                            .case_idx = .special,
                         } }),
                         &.{.bool_false},
                         .bool_false,
@@ -12571,7 +12612,7 @@ fn analyzeSwitchRuntimeBlock(
                 special.capture,
                 child_block.src(.{ .switch_capture = .{
                     .switch_node_offset = switch_node_offset,
-                    .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+                    .case_idx = .special,
                 } }),
                 undefined, // case_vals may be undefined for special prongs
                 .none,
@@ -12836,7 +12877,7 @@ fn resolveSwitchComptime(
         special.capture,
         child_block.src(.{ .switch_capture = .{
             .switch_node_offset = switch_node_offset,
-            .case_idx = LazySrcLoc.Offset.SwitchCaseIndex.special,
+            .case_idx = .special,
         } }),
         undefined, // case_vals may be undefined for special prongs
         if (special.is_inline) cond_operand else .none,
src/Zcu.zig
@@ -1684,13 +1684,13 @@ pub const SrcLoc = struct {
                 const case_nodes = tree.extraDataSlice(tree.extraData(extra_index, Ast.Node.SubRange), Ast.Node.Index);
                 for (case_nodes) |case_node| {
                     const case = tree.fullSwitchCase(case_node).?;
-                    const is_special = (case.ast.values.len == 0) or
-                        (case.ast.values.len == 1 and
-                            tree.nodeTag(case.ast.values[0]) == .identifier and
-                            mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(case.ast.values[0])), "_"));
-                    if (!is_special) continue;
-
-                    return tree.nodeToSpan(case_node);
+                    if (case.isSpecial(tree)) |special_node| {
+                        return tree.tokensToSpan(
+                            tree.firstToken(case_node),
+                            tree.lastToken(case_node),
+                            tree.nodeMainToken(special_node.unwrap() orelse case_node),
+                        );
+                    }
                 } else unreachable;
             },
 
@@ -1701,11 +1701,9 @@ pub const SrcLoc = struct {
                 const case_nodes = tree.extraDataSlice(tree.extraData(extra_index, Ast.Node.SubRange), Ast.Node.Index);
                 for (case_nodes) |case_node| {
                     const case = tree.fullSwitchCase(case_node).?;
-                    const is_special = (case.ast.values.len == 0) or
-                        (case.ast.values.len == 1 and
-                            tree.nodeTag(case.ast.values[0]) == .identifier and
-                            mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(case.ast.values[0])), "_"));
-                    if (is_special) continue;
+                    if (case.isSpecial(tree)) |maybe_else| {
+                        if (maybe_else == .none) continue;
+                    }
 
                     for (case.ast.values) |item_node| {
                         if (tree.nodeTag(item_node) == .switch_range) {
@@ -2111,17 +2109,21 @@ pub const SrcLoc = struct {
 
                 var multi_i: u32 = 0;
                 var scalar_i: u32 = 0;
+                var found_special = false;
+                var underscore_node: Ast.Node.OptionalIndex = .none;
                 const case = for (case_nodes) |case_node| {
                     const case = tree.fullSwitchCase(case_node).?;
                     const is_special = special: {
-                        if (case.ast.values.len == 0) break :special true;
-                        if (case.ast.values.len == 1 and tree.nodeTag(case.ast.values[0]) == .identifier) {
-                            break :special mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(case.ast.values[0])), "_");
+                        if (found_special) break :special false;
+                        if (case.isSpecial(tree)) |special_node| {
+                            underscore_node = special_node;
+                            found_special = true;
+                            break :special true;
                         }
                         break :special false;
                     };
                     if (is_special) {
-                        if (want_case_idx.isSpecial()) {
+                        if (want_case_idx == LazySrcLoc.Offset.SwitchCaseIndex.special) {
                             break case;
                         }
                         continue;
@@ -2171,7 +2173,11 @@ pub const SrcLoc = struct {
                     .single => {
                         var item_i: u32 = 0;
                         for (case.ast.values) |item_node| {
-                            if (tree.nodeTag(item_node) == .switch_range) continue;
+                            if (item_node.toOptional() == underscore_node or
+                                tree.nodeTag(item_node) == .switch_range)
+                            {
+                                continue;
+                            }
                             if (item_i != want_item.index) {
                                 item_i += 1;
                                 continue;
@@ -2182,7 +2188,9 @@ pub const SrcLoc = struct {
                     .range => {
                         var range_i: u32 = 0;
                         for (case.ast.values) |item_node| {
-                            if (tree.nodeTag(item_node) != .switch_range) continue;
+                            if (tree.nodeTag(item_node) != .switch_range) {
+                                continue;
+                            }
                             if (range_i != want_item.index) {
                                 range_i += 1;
                                 continue;
@@ -2561,9 +2569,6 @@ pub const LazySrcLoc = struct {
             index: u31,
 
             pub const special: SwitchCaseIndex = @bitCast(@as(u32, std.math.maxInt(u32)));
-            pub fn isSpecial(idx: SwitchCaseIndex) bool {
-                return @as(u32, @bitCast(idx)) == @as(u32, @bitCast(special));
-            }
         };
 
         pub const SwitchItemIndex = packed struct(u32) {
test/behavior/switch.zig
@@ -1073,3 +1073,28 @@ test "switch on 8-bit mod result" {
         else => unreachable,
     }
 }
+
+test "switch on non-exhaustive enum" {
+    const E = enum(u32) {
+        a,
+        b,
+        c,
+        _,
+    };
+
+    var e: E = .a;
+    _ = &e;
+    switch (e) {
+        .a, .b => {},
+        else => return error.TestFailed,
+    }
+    switch (e) {
+        .a, .b => {},
+        .c => return error.TestFailed,
+        _ => return error.TestFailed,
+    }
+    switch (e) {
+        .a, .b => {},
+        .c, _ => return error.TestFailed,
+    }
+}
test/cases/compile_errors/switch_expression-non_exhaustive_absorbing.zig
@@ -0,0 +1,33 @@
+const E = enum(u8) {
+    a,
+    b,
+    _,
+};
+const U = union(E) {
+    a: i32,
+    b: u32,
+};
+pub export fn entry1() void {
+    const e: E = .b;
+    switch (e) { // error: switch not handling the tag `b`
+        .a, _ => {},
+    }
+}
+pub export fn entry2() void {
+    const u = U{ .a = 2 };
+    switch (u) { // error: `_` prong not allowed when switching on tagged union
+        .a => {},
+        .b, _ => {},
+    }
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :12:5: error: switch must handle all possibilities
+// :3:5: note: unhandled enumeration value: 'b'
+// :1:11: note: enum 'tmp.E' declared here
+// :18:5: error: '_' prong only allowed when switching on non-exhaustive enums
+// :20:13: note: '_' prong here
+// :18:5: note: consider using 'else'
test/cases/compile_errors/switching_with_exhaustive_enum_has___prong_.zig
@@ -16,5 +16,5 @@ pub export fn entry() void {
 // target=native
 //
 // :7:5: error: '_' prong only allowed when switching on non-exhaustive enums
-// :10:11: note: '_' prong here
+// :10:9: note: '_' prong here
 // :7:5: note: consider using 'else'
test/cases/compile_errors/switching_with_non-exhaustive_enums.zig
@@ -39,5 +39,5 @@ pub export fn entry3() void {
 // :1:11: note: enum 'tmp.E' declared here
 // :19:5: error: switch on non-exhaustive enum must include 'else' or '_' prong
 // :26:5: error: '_' prong only allowed when switching on non-exhaustive enums
-// :29:11: note: '_' prong here
+// :29:9: note: '_' prong here
 // :26:5: note: consider using 'else'