Commit 00609e7edb

mlugg <mlugg@mlugg.co.uk>
2023-05-26 05:10:51
Eliminate switch_capture and switch_capture_ref ZIR tags
These tags are unnecessary, as this information can be more efficiently encoded within the switch_block instruction itself. We also use a neat little trick to avoid needing a dummy instruction (like is used for errdefer captures): since the switch_block itself cannot otherwise be referenced within a prong, we can repurpose its index within prongs to refer to the captured value.
1 parent cebd800
src/AstGen.zig
@@ -2612,8 +2612,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .switch_block,
             .switch_cond,
             .switch_cond_ref,
-            .switch_capture,
-            .switch_capture_ref,
             .switch_capture_tag,
             .struct_init_empty,
             .struct_init,
@@ -6876,17 +6874,22 @@ fn switchExpr(
         var dbg_var_inst: Zir.Inst.Ref = undefined;
         var dbg_var_tag_name: ?u32 = null;
         var dbg_var_tag_inst: Zir.Inst.Ref = undefined;
-        var capture_inst: Zir.Inst.Index = 0;
         var tag_inst: Zir.Inst.Index = 0;
         var capture_val_scope: Scope.LocalVal = undefined;
         var tag_scope: Scope.LocalVal = undefined;
+
+        var capture: Zir.Inst.SwitchBlock.ProngInfo.Capture = .none;
+
         const sub_scope = blk: {
             const payload_token = case.payload_token orelse break :blk &case_scope.base;
             const ident = if (token_tags[payload_token] == .asterisk)
                 payload_token + 1
             else
                 payload_token;
+
             const is_ptr = ident != payload_token;
+            capture = if (is_ptr) .by_ref else .by_val;
+
             const ident_slice = tree.tokenSlice(ident);
             var payload_sub_scope: *Scope = undefined;
             if (mem.eql(u8, ident_slice, "_")) {
@@ -6895,46 +6898,18 @@ fn switchExpr(
                 }
                 payload_sub_scope = &case_scope.base;
             } else {
-                if (case_node == special_node) {
-                    const capture_tag: Zir.Inst.Tag = if (is_ptr)
-                        .switch_capture_ref
-                    else
-                        .switch_capture;
-                    capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
-                    try astgen.instructions.append(gpa, .{
-                        .tag = capture_tag,
-                        .data = .{
-                            .switch_capture = .{
-                                .switch_inst = switch_block,
-                                // Max int communicates that this is the else/underscore prong.
-                                .prong_index = std.math.maxInt(u32),
-                            },
-                        },
-                    });
-                } else {
-                    const capture_tag: Zir.Inst.Tag = if (is_ptr) .switch_capture_ref else .switch_capture;
-                    const capture_index = if (is_multi_case) scalar_cases_len + multi_case_index else scalar_case_index;
-                    capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
-                    try astgen.instructions.append(gpa, .{
-                        .tag = capture_tag,
-                        .data = .{ .switch_capture = .{
-                            .switch_inst = switch_block,
-                            .prong_index = capture_index,
-                        } },
-                    });
-                }
                 const capture_name = try astgen.identAsString(ident);
                 try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice, .capture);
                 capture_val_scope = .{
                     .parent = &case_scope.base,
                     .gen_zir = &case_scope,
                     .name = capture_name,
-                    .inst = indexToRef(capture_inst),
+                    .inst = indexToRef(switch_block),
                     .token_src = payload_token,
                     .id_cat = .capture,
                 };
                 dbg_var_name = capture_name;
-                dbg_var_inst = indexToRef(capture_inst);
+                dbg_var_inst = indexToRef(switch_block);
                 payload_sub_scope = &capture_val_scope.base;
             }
 
@@ -7023,7 +6998,6 @@ fn switchExpr(
             case_scope.instructions_top = parent_gz.instructions.items.len;
             defer case_scope.unstack();
 
-            if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst);
             if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst);
             try case_scope.addDbgBlockBegin();
             if (dbg_var_name) |some| {
@@ -7042,10 +7016,28 @@ fn switchExpr(
             }
 
             const case_slice = case_scope.instructionsSlice();
-            const body_len = astgen.countBodyLenAfterFixups(case_slice);
+            // Since we use the switch_block instruction itself to refer to the
+            // capture, which will not be added to the child block, we need to
+            // handle ref_table manually.
+            const refs_len = refs: {
+                var n: usize = 0;
+                var check_inst = switch_block;
+                while (astgen.ref_table.get(check_inst)) |ref_inst| {
+                    n += 1;
+                    check_inst = ref_inst;
+                }
+                break :refs n;
+            };
+            const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice);
             try payloads.ensureUnusedCapacity(gpa, body_len);
-            const inline_bit = @as(u32, @boolToInt(case.inline_token != null)) << 31;
-            payloads.items[body_len_index] = body_len | inline_bit;
+            payloads.items[body_len_index] = @bitCast(u32, Zir.Inst.SwitchBlock.ProngInfo{
+                .body_len = @intCast(u29, body_len),
+                .capture = capture,
+                .is_inline = case.inline_token != null,
+            });
+            if (astgen.ref_table.fetchRemove(switch_block)) |kv| {
+                appendPossiblyRefdBodyInst(astgen, payloads, kv.value);
+            }
             appendBodyWithFixupsArrayList(astgen, payloads, case_slice);
         }
     }
@@ -7092,7 +7084,7 @@ fn switchExpr(
             end_index += 3 + items_len + 2 * ranges_len;
         }
 
-        const body_len = @truncate(u31, payloads.items[body_len_index]);
+        const body_len = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, payloads.items[body_len_index]).body_len;
         end_index += body_len;
 
         switch (strat.tag) {
src/Module.zig
@@ -5871,6 +5871,7 @@ pub const SwitchProngSrc = union(enum) {
     multi: Multi,
     range: Multi,
     multi_capture: u32,
+    special,
 
     pub const Multi = struct {
         prong: u32,
@@ -5908,14 +5909,22 @@ pub const SwitchProngSrc = union(enum) {
         var scalar_i: u32 = 0;
         for (case_nodes) |case_node| {
             const case = tree.fullSwitchCase(case_node).?;
-            if (case.ast.values.len == 0)
-                continue;
-            if (case.ast.values.len == 1 and
-                node_tags[case.ast.values[0]] == .identifier and
-                mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_"))
-            {
-                continue;
+
+            const is_special = special: {
+                if (case.ast.values.len == 0) break :special true;
+                if (case.ast.values.len == 1 and node_tags[case.ast.values[0]] == .identifier) {
+                    break :special mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_");
+                }
+                break :special false;
+            };
+
+            if (is_special) {
+                if (prong_src != .special) continue;
+                return LazySrcLoc.nodeOffset(
+                    decl.nodeIndexToRelative(case.ast.values[0]),
+                );
             }
+
             const is_multi = case.ast.values.len != 1 or
                 node_tags[case.ast.values[0]] == .switch_range;
 
@@ -5956,6 +5965,7 @@ pub const SwitchProngSrc = union(enum) {
                         range_i += 1;
                     } else unreachable;
                 },
+                .special => {},
             }
             if (is_multi) {
                 multi_i += 1;
src/print_zir.zig
@@ -436,10 +436,6 @@ const Writer = struct {
 
             .@"unreachable" => try self.writeUnreachable(stream, inst),
 
-            .switch_capture,
-            .switch_capture_ref,
-            => try self.writeSwitchCapture(stream, inst),
-
             .dbg_stmt => try self.writeDbgStmt(stream, inst),
 
             .dbg_block_begin,
@@ -1913,15 +1909,20 @@ const Writer = struct {
                 else => break :else_prong,
             };
 
-            const body_len = @truncate(u31, self.code.extra[extra_index]);
-            const inline_text = if (self.code.extra[extra_index] >> 31 != 0) "inline " else "";
+            const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
+            const capture_text = switch (info.capture) {
+                .none => "",
+                .by_val => "by_val ",
+                .by_ref => "by_ref ",
+            };
+            const inline_text = if (info.is_inline) "inline " else "";
             extra_index += 1;
-            const body = self.code.extra[extra_index..][0..body_len];
+            const body = self.code.extra[extra_index..][0..info.body_len];
             extra_index += body.len;
 
             try stream.writeAll(",\n");
             try stream.writeByteNTimes(' ', self.indent);
-            try stream.print("{s}{s} => ", .{ inline_text, prong_name });
+            try stream.print("{s}{s}{s} => ", .{ capture_text, inline_text, prong_name });
             try self.writeBracedBody(stream, body);
         }
 
@@ -1931,15 +1932,19 @@ const Writer = struct {
             while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                 const item_ref = @intToEnum(Zir.Inst.Ref, self.code.extra[extra_index]);
                 extra_index += 1;
-                const body_len = @truncate(u31, self.code.extra[extra_index]);
-                const is_inline = self.code.extra[extra_index] >> 31 != 0;
+                const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
                 extra_index += 1;
-                const body = self.code.extra[extra_index..][0..body_len];
-                extra_index += body_len;
+                const body = self.code.extra[extra_index..][0..info.body_len];
+                extra_index += info.body_len;
 
                 try stream.writeAll(",\n");
                 try stream.writeByteNTimes(' ', self.indent);
-                if (is_inline) try stream.writeAll("inline ");
+                switch (info.capture) {
+                    .none => {},
+                    .by_val => try stream.writeAll("by_val "),
+                    .by_ref => try stream.writeAll("by_ref "),
+                }
+                if (info.is_inline) try stream.writeAll("inline ");
                 try self.writeInstRef(stream, item_ref);
                 try stream.writeAll(" => ");
                 try self.writeBracedBody(stream, body);
@@ -1952,15 +1957,19 @@ const Writer = struct {
                 extra_index += 1;
                 const ranges_len = self.code.extra[extra_index];
                 extra_index += 1;
-                const body_len = @truncate(u31, self.code.extra[extra_index]);
-                const is_inline = self.code.extra[extra_index] >> 31 != 0;
+                const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
                 extra_index += 1;
                 const items = self.code.refSlice(extra_index, items_len);
                 extra_index += items_len;
 
                 try stream.writeAll(",\n");
                 try stream.writeByteNTimes(' ', self.indent);
-                if (is_inline) try stream.writeAll("inline ");
+                switch (info.capture) {
+                    .none => {},
+                    .by_val => try stream.writeAll("by_val "),
+                    .by_ref => try stream.writeAll("by_ref "),
+                }
+                if (info.is_inline) try stream.writeAll("inline ");
 
                 for (items, 0..) |item_ref, item_i| {
                     if (item_i != 0) try stream.writeAll(", ");
@@ -1982,8 +1991,8 @@ const Writer = struct {
                     try self.writeInstRef(stream, item_last);
                 }
 
-                const body = self.code.extra[extra_index..][0..body_len];
-                extra_index += body_len;
+                const body = self.code.extra[extra_index..][0..info.body_len];
+                extra_index += info.body_len;
                 try stream.writeAll(" => ");
                 try self.writeBracedBody(stream, body);
             }
@@ -2435,12 +2444,6 @@ const Writer = struct {
         try self.writeSrc(stream, src);
     }
 
-    fn writeSwitchCapture(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
-        const inst_data = self.code.instructions.items(.data)[inst].switch_capture;
-        try self.writeInstIndex(stream, inst_data.switch_inst);
-        try stream.print(", {d})", .{inst_data.prong_index});
-    }
-
     fn writeDbgStmt(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].dbg_stmt;
         try stream.print("{d}, {d})", .{ inst_data.line + 1, inst_data.column + 1 });
src/Sema.zig
@@ -277,9 +277,6 @@ pub const Block = struct {
 
     c_import_buf: ?*std.ArrayList(u8) = null,
 
-    /// type of `err` in `else => |err|`
-    switch_else_err_ty: ?Type = null,
-
     /// Value for switch_capture in an inline case
     inline_case_capture: Air.Inst.Ref = .none,
 
@@ -397,7 +394,6 @@ pub const Block = struct {
             .want_safety = parent.want_safety,
             .float_mode = parent.float_mode,
             .c_import_buf = parent.c_import_buf,
-            .switch_else_err_ty = parent.switch_else_err_ty,
             .error_return_trace_index = parent.error_return_trace_index,
         };
     }
@@ -1017,8 +1013,6 @@ fn analyzeBodyInner(
             .switch_block                 => try sema.zirSwitchBlock(block, inst),
             .switch_cond                  => try sema.zirSwitchCond(block, inst, false),
             .switch_cond_ref              => try sema.zirSwitchCond(block, inst, true),
-            .switch_capture               => try sema.zirSwitchCapture(block, inst, false),
-            .switch_capture_ref           => try sema.zirSwitchCapture(block, inst, true),
             .switch_capture_tag           => try sema.zirSwitchCaptureTag(block, inst),
             .type_info                    => try sema.zirTypeInfo(block, inst),
             .size_of                      => try sema.zirSizeOf(block, inst),
@@ -10083,61 +10077,160 @@ fn zirSliceLength(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     return sema.analyzeSlice(block, src, array_ptr, start, len, sentinel, sentinel_src, ptr_src, start_src, end_src, true);
 }
 
-fn zirSwitchCapture(
+/// Resolve a switch prong which is determined at comptime to have no peers. Uses
+/// `resolveBlockBody`. Sets up captures as needed.
+fn resolveSwitchProngComptime(
     sema: *Sema,
-    block: *Block,
-    inst: Zir.Inst.Index,
-    is_ref: bool,
+    parent_block: *Block,
+    child_block: *Block,
+    src: LazySrcLoc,
+    operand: Air.Inst.Ref,
+    operand_ptr: Air.Inst.Ref,
+    prong_type: enum { normal, special },
+    prong_body: []const Zir.Inst.Index,
+    capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
+    raw_capture_src: Module.SwitchProngSrc,
+    else_error_ty: ?Type,
+    case_vals: []const Air.Inst.Ref,
+    switch_block_inst: Zir.Inst.Index,
+    merges: *Block.Merges,
 ) CompileError!Air.Inst.Ref {
-    const tracy = trace(@src());
-    defer tracy.end();
+    switch (capture) {
+        .none => {
+            return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges);
+        },
+
+        .by_val, .by_ref => {
+            const zir_datas = sema.code.instructions.items(.data);
+            const switch_info = zir_datas[switch_block_inst].pl_node;
+
+            const capture_ref = try sema.analyzeSwitchCapture(
+                child_block,
+                capture == .by_ref,
+                operand,
+                operand_ptr,
+                switch_info.src_node,
+                prong_type == .special,
+                raw_capture_src,
+                else_error_ty,
+                case_vals,
+            );
+
+            if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) {
+                // This prong should be unreachable!
+                return Air.Inst.Ref.unreachable_value;
+            }
+
+            sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref);
+            defer assert(sema.inst_map.remove(switch_block_inst));
+
+            return sema.resolveBlockBody(parent_block, src, child_block, prong_body, switch_block_inst, merges);
+        },
+    }
+}
+
+/// Analyze a switch prong which may have peers at runtime. Uses
+/// `analyzeBodyRuntimeBreak`. Sets up captures as needed.
+fn analyzeSwitchProngRuntime(
+    sema: *Sema,
+    case_block: *Block,
+    operand: Air.Inst.Ref,
+    operand_ptr: Air.Inst.Ref,
+    prong_type: enum { normal, special },
+    prong_body: []const Zir.Inst.Index,
+    capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
+    raw_capture_src: Module.SwitchProngSrc,
+    else_error_ty: ?Type,
+    case_vals: []const Air.Inst.Ref,
+    switch_block_inst: Zir.Inst.Index,
+) CompileError!void {
+    switch (capture) {
+        .none => {
+            return sema.analyzeBodyRuntimeBreak(case_block, prong_body);
+        },
+
+        .by_val, .by_ref => {
+            const zir_datas = sema.code.instructions.items(.data);
+            const switch_info = zir_datas[switch_block_inst].pl_node;
+
+            const capture_ref = try sema.analyzeSwitchCapture(
+                case_block,
+                capture == .by_ref,
+                operand,
+                operand_ptr,
+                switch_info.src_node,
+                prong_type == .special,
+                raw_capture_src,
+                else_error_ty,
+                case_vals,
+            );
 
+            if (sema.typeOf(capture_ref).isNoReturn(sema.mod)) {
+                // No need to analyze any further, the prong is unreachable
+                return;
+            }
+
+            sema.inst_map.putAssumeCapacity(switch_block_inst, capture_ref);
+            defer assert(sema.inst_map.remove(switch_block_inst));
+
+            return sema.analyzeBodyRuntimeBreak(case_block, prong_body);
+        },
+    }
+}
+
+fn analyzeSwitchCapture(
+    sema: *Sema,
+    /// Must be the child block so that `inline_case_capture` is set for inline prongs.
+    block: *Block,
+    capture_byref: bool,
+    /// The raw switch operand value.
+    operand: Air.Inst.Ref,
+    /// Pointer to the raw switch operand. May be undefined if `capture_byref` is false.
+    operand_ptr: Air.Inst.Ref,
+    switch_node_offset: i32,
+    /// `true` if this is the `else` or `_` prong of a switch.
+    is_special_prong: bool,
+    /// Must use the `scalar`, `special`, or `multi_capture` union field.
+    raw_capture_src: Module.SwitchProngSrc,
+    /// If this is the `else` prong of a switch on an error set, this is the
+    /// type that should be assigned to the capture. If `null`, the prong should
+    /// be unreachable.
+    else_error_ty: ?Type,
+    /// The set of all values which can reach this prong. May be undefined if
+    /// the prong has `is_special_prong` or contains ranges.
+    case_vals: []const Air.Inst.Ref,
+) CompileError!Air.Inst.Ref {
     const mod = sema.mod;
     const gpa = sema.gpa;
-    const zir_datas = sema.code.instructions.items(.data);
-    const capture_info = zir_datas[inst].switch_capture;
-    const switch_info = zir_datas[capture_info.switch_inst].pl_node;
-    const switch_extra = sema.code.extraData(Zir.Inst.SwitchBlock, switch_info.payload_index);
-    const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_info.src_node };
-    const cond = try sema.resolveInst(switch_extra.data.operand);
-    const cond_ty = sema.typeOf(cond);
-    const cond_inst = Zir.refToIndex(switch_extra.data.operand).?;
-    const cond_info = zir_datas[cond_inst].un_node;
-    const cond_tag = sema.code.instructions.items(.tag)[cond_inst];
-    const operand_is_ref = cond_tag == .switch_cond_ref;
-    const operand_ptr = try sema.resolveInst(cond_info.operand);
-    const operand_ptr_ty = sema.typeOf(operand_ptr);
-    const operand_ty = if (operand_is_ref) operand_ptr_ty.childType(mod) else operand_ptr_ty;
+    const operand_ty = sema.typeOf(operand);
+    const operand_ptr_ty = if (capture_byref) sema.typeOf(operand_ptr) else undefined;
+    const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset };
 
     if (block.inline_case_capture != .none) {
-        const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, undefined) catch unreachable;
-        const resolved_item_val = try sema.resolveLazyValue(item_val);
+        const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, "") catch unreachable;
         if (operand_ty.zigTypeTag(mod) == .Union) {
-            const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(resolved_item_val, mod).?);
+            const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, mod).?);
             const union_obj = mod.typeToUnion(operand_ty).?;
             const field_ty = union_obj.fields.values()[field_index].ty;
-            if (try sema.resolveDefinedValue(block, sema.src, operand_ptr)) |union_val| {
-                if (is_ref) {
+            if (capture_byref) {
+                if (try sema.resolveDefinedValue(block, sema.src, operand_ptr)) |union_ptr| {
                     const ptr_field_ty = try Type.ptr(sema.arena, mod, .{
                         .pointee_type = field_ty,
                         .mutable = operand_ptr_ty.ptrIsMutable(mod),
                         .@"volatile" = operand_ptr_ty.isVolatilePtr(mod),
                         .@"addrspace" = operand_ptr_ty.ptrAddressSpace(mod),
                     });
-                    return sema.addConstant(ptr_field_ty, (try mod.intern(.{ .ptr = .{
-                        .ty = ptr_field_ty.toIntern(),
-                        .addr = .{ .field = .{
-                            .base = union_val.toIntern(),
-                            .index = field_index,
-                        } },
-                    } })).toValue());
+                    return sema.addConstant(
+                        ptr_field_ty,
+                        (try mod.intern(.{ .ptr = .{
+                            .ty = ptr_field_ty.toIntern(),
+                            .addr = .{ .field = .{
+                                .base = union_ptr.toIntern(),
+                                .index = field_index,
+                            } },
+                        } })).toValue(),
+                    );
                 }
-                return sema.addConstant(
-                    field_ty,
-                    mod.intern_pool.indexToKey(union_val.toIntern()).un.val.toValue(),
-                );
-            }
-            if (is_ref) {
                 const ptr_field_ty = try Type.ptr(sema.arena, mod, .{
                     .pointee_type = field_ty,
                     .mutable = operand_ptr_ty.ptrIsMutable(mod),
@@ -10146,29 +10239,27 @@ fn zirSwitchCapture(
                 });
                 return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty);
             } else {
-                return block.addStructFieldVal(operand_ptr, field_index, field_ty);
+                if (try sema.resolveDefinedValue(block, sema.src, operand)) |union_val| {
+                    const tag_and_val = mod.intern_pool.indexToKey(union_val.toIntern()).un;
+                    return sema.addConstant(field_ty, tag_and_val.val.toValue());
+                }
+                return block.addStructFieldVal(operand, field_index, field_ty);
             }
-        } else if (is_ref) {
-            return sema.addConstantMaybeRef(block, operand_ty, resolved_item_val, true);
+        } else if (capture_byref) {
+            return sema.addConstantMaybeRef(block, operand_ty, item_val, true);
         } else {
             return block.inline_case_capture;
         }
     }
 
-    const operand = if (operand_is_ref)
-        try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src)
-    else
-        operand_ptr;
-
-    if (capture_info.prong_index == std.math.maxInt(@TypeOf(capture_info.prong_index))) {
-        // It is the else/`_` prong.
-        if (is_ref) {
+    if (is_special_prong) {
+        if (capture_byref) {
             return operand_ptr;
         }
 
         switch (operand_ty.zigTypeTag(mod)) {
-            .ErrorSet => if (block.switch_else_err_ty) |some| {
-                return sema.bitCast(block, some, operand, operand_src, null);
+            .ErrorSet => if (else_error_ty) |ty| {
+                return sema.bitCast(block, ty, operand, operand_src, null);
             } else {
                 try block.addUnreachable(false);
                 return Air.Inst.Ref.unreachable_value;
@@ -10177,41 +10268,33 @@ fn zirSwitchCapture(
         }
     }
 
-    // Note that these are the *uncasted* prong items.
-    // Also note that items from ranges are not included so this only works for non-ranged types.
-    const items = switch_extra.data.getProng(sema.code, switch_extra.end, capture_info.prong_index).items;
-
     switch (operand_ty.zigTypeTag(mod)) {
         .Union => {
             const union_obj = mod.typeToUnion(operand_ty).?;
-            const first_item = try sema.resolveInst(items[0]);
-            // Previous switch validation ensured this will succeed
-            const first_item_coerced = try sema.coerce(block, cond_ty, first_item, .unneeded);
-            const first_item_val = sema.resolveConstValue(block, .unneeded, first_item_coerced, "") catch unreachable;
+            const first_item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable;
 
             const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?);
             const first_field = union_obj.fields.values()[first_field_index];
 
-            for (items[1..], 0..) |item, i| {
-                const item_ref = try sema.resolveInst(item);
-                // Previous switch validation ensured this will succeed
-                const item_coerced = try sema.coerce(block, cond_ty, item_ref, .unneeded);
-                const item_val = sema.resolveConstValue(block, .unneeded, item_coerced, "") catch unreachable;
+            for (case_vals[1..], 0..) |item, i| {
+                const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable;
 
                 const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?;
                 const field = union_obj.fields.values()[field_index];
                 if (!field.ty.eql(first_field.ty, mod)) {
                     const msg = msg: {
-                        const raw_capture_src = Module.SwitchProngSrc{ .multi_capture = capture_info.prong_index };
-                        const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first);
+                        const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none);
 
                         const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
                         errdefer msg.destroy(gpa);
 
-                        const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = capture_info.prong_index, .item = 0 } };
-                        const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first);
-                        const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = capture_info.prong_index, .item = 1 + @intCast(u32, i) } };
-                        const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_info.src_node, .first);
+                        // This must be a multi-prong so this must be a `multi_capture` src
+                        const multi_idx = raw_capture_src.multi_capture;
+
+                        const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } };
+                        const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
+                        const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } };
+                        const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
                         try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)});
                         try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)});
                         break :msg msg;
@@ -10220,7 +10303,7 @@ fn zirSwitchCapture(
                 }
             }
 
-            if (is_ref) {
+            if (capture_byref) {
                 const field_ty_ptr = try Type.ptr(sema.arena, mod, .{
                     .pointee_type = first_field.ty,
                     .@"addrspace" = .generic,
@@ -10250,31 +10333,35 @@ fn zirSwitchCapture(
             return block.addStructFieldVal(operand, first_field_index, first_field.ty);
         },
         .ErrorSet => {
-            if (items.len > 1) {
-                var names: Module.Fn.InferredErrorSet.NameMap = .{};
-                try names.ensureUnusedCapacity(sema.arena, items.len);
-                for (items) |item| {
-                    const item_ref = try sema.resolveInst(item);
-                    // Previous switch validation ensured this will succeed
-                    const item_val = sema.resolveConstLazyValue(block, .unneeded, item_ref, "") catch unreachable;
-                    names.putAssumeCapacityNoClobber(item_val.getErrorName(mod).unwrap().?, {});
-                }
-                const else_error_ty = try mod.errorSetFromUnsortedNames(names.keys());
-
-                return sema.bitCast(block, else_error_ty, operand, operand_src, null);
-            } else {
-                const item_ref = try sema.resolveInst(items[0]);
-                // Previous switch validation ensured this will succeed
-                const item_val = sema.resolveConstLazyValue(block, .unneeded, item_ref, "") catch unreachable;
+            if (capture_byref) {
+                const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none);
+                return sema.fail(
+                    block,
+                    capture_src,
+                    "error set cannot be captured by reference",
+                    .{},
+                );
+            }
 
+            if (case_vals.len == 1) {
+                const item_val = sema.resolveConstValue(block, .unneeded, case_vals[0], "") catch unreachable;
                 const item_ty = try mod.singleErrorSetType(item_val.getErrorName(mod).unwrap().?);
                 return sema.bitCast(block, item_ty, operand, operand_src, null);
             }
+
+            var names: Module.Fn.InferredErrorSet.NameMap = .{};
+            try names.ensureUnusedCapacity(sema.arena, case_vals.len);
+            for (case_vals) |err| {
+                const err_val = sema.resolveConstValue(block, .unneeded, err, "") catch unreachable;
+                names.putAssumeCapacityNoClobber(err_val.getErrorName(mod).unwrap().?, {});
+            }
+            const error_ty = try mod.errorSetFromUnsortedNames(names.keys());
+            return sema.bitCast(block, error_ty, operand, operand_src, null);
         },
         else => {
-            // In this case the capture value is just the passed-through value of the
-            // switch condition.
-            if (is_ref) {
+            // In this case the capture value is just the passed-through value
+            // of the switch condition.
+            if (capture_byref) {
                 return operand_ptr;
             } else {
                 return operand;
@@ -10415,28 +10502,42 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len);
     defer case_vals.deinit(gpa);
 
+    const Special = struct {
+        body: []const Zir.Inst.Index,
+        end: usize,
+        capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
+        is_inline: bool,
+    };
+
     const special_prong = extra.data.bits.specialProng();
-    const special: struct { body: []const Zir.Inst.Index, end: usize, is_inline: bool } = switch (special_prong) {
-        .none => .{ .body = &.{}, .end = header_extra_index, .is_inline = false },
+    const special: Special = switch (special_prong) {
+        .none => .{ .body = &.{}, .end = header_extra_index, .capture = .none, .is_inline = false },
         .under, .@"else" => blk: {
-            const body_len = @truncate(u31, sema.code.extra[header_extra_index]);
+            const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[header_extra_index]);
             const extra_body_start = header_extra_index + 1;
             break :blk .{
-                .body = sema.code.extra[extra_body_start..][0..body_len],
-                .end = extra_body_start + body_len,
-                .is_inline = sema.code.extra[header_extra_index] >> 31 != 0,
+                .body = sema.code.extra[extra_body_start..][0..info.body_len],
+                .end = extra_body_start + info.body_len,
+                .capture = info.capture,
+                .is_inline = info.is_inline,
             };
         },
     };
 
-    const maybe_union_ty = blk: {
+    const raw_operand: struct { val: Air.Inst.Ref, ptr: Air.Inst.Ref } = blk: {
         const zir_tags = sema.code.instructions.items(.tag);
         const zir_data = sema.code.instructions.items(.data);
         const cond_index = Zir.refToIndex(extra.data.operand).?;
-        const raw_operand = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable;
-        const target_ty = sema.typeOf(raw_operand);
-        break :blk if (zir_tags[cond_index] == .switch_cond_ref) target_ty.childType(mod) else target_ty;
+        const raw = sema.resolveInst(zir_data[cond_index].un_node.operand) catch unreachable;
+        if (zir_tags[cond_index] == .switch_cond_ref) {
+            const val = try sema.analyzeLoad(block, src, raw, operand_src);
+            break :blk .{ .val = val, .ptr = raw };
+        } else {
+            break :blk .{ .val = raw, .ptr = undefined };
+        }
     };
+
+    const maybe_union_ty = sema.typeOf(raw_operand.val);
     const union_originally = maybe_union_ty.zigTypeTag(mod) == .Union;
 
     // Duplicate checking variables later also used for `inline else`.
@@ -10496,9 +10597,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    extra_index += body_len;
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
+                    extra_index += 1 + info.body_len;
 
                     case_vals.appendAssumeCapacity(try sema.validateSwitchItemEnum(
                         block,
@@ -10518,10 +10618,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     extra_index += 1;
                     const ranges_len = sema.code.extra[extra_index];
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
                     const items = sema.code.refSlice(extra_index, items_len);
-                    extra_index += items_len + body_len;
+                    extra_index += items_len + info.body_len;
 
                     try case_vals.ensureUnusedCapacity(gpa, items.len);
                     for (items, 0..) |item_ref, item_i| {
@@ -10596,9 +10696,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    extra_index += body_len;
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
+                    extra_index += 1 + info.body_len;
 
                     case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
                         block,
@@ -10617,10 +10716,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     extra_index += 1;
                     const ranges_len = sema.code.extra[extra_index];
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
                     const items = sema.code.refSlice(extra_index, items_len);
-                    extra_index += items_len + body_len;
+                    extra_index += items_len + info.body_len;
 
                     try case_vals.ensureUnusedCapacity(gpa, items.len);
                     for (items, 0..) |item_ref, item_i| {
@@ -10694,7 +10793,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                         .dbg_block_end,
                         .dbg_stmt,
                         .dbg_var_val,
-                        .switch_capture,
                         .ret_type,
                         .as_node,
                         .ret_node,
@@ -10739,9 +10837,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    extra_index += body_len;
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
+                    extra_index += 1 + info.body_len;
 
                     case_vals.appendAssumeCapacity(try sema.validateSwitchItemInt(
                         block,
@@ -10760,7 +10857,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     extra_index += 1;
                     const ranges_len = sema.code.extra[extra_index];
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
                     const items = sema.code.refSlice(extra_index, items_len);
                     extra_index += items_len;
@@ -10798,7 +10895,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                         case_vals.appendAssumeCapacity(vals[1]);
                     }
 
-                    extra_index += body_len;
+                    extra_index += info.body_len;
                 }
             }
 
@@ -10835,9 +10932,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    extra_index += body_len;
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
+                    extra_index += 1 + info.body_len;
 
                     case_vals.appendAssumeCapacity(try sema.validateSwitchItemBool(
                         block,
@@ -10856,10 +10952,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     extra_index += 1;
                     const ranges_len = sema.code.extra[extra_index];
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
                     const items = sema.code.refSlice(extra_index, items_len);
-                    extra_index += items_len + body_len;
+                    extra_index += items_len + info.body_len;
 
                     try case_vals.ensureUnusedCapacity(gpa, items.len);
                     for (items, 0..) |item_ref, item_i| {
@@ -10918,9 +11014,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
-                    extra_index += body_len;
+                    extra_index += info.body_len;
 
                     case_vals.appendAssumeCapacity(try sema.validateSwitchItemSparse(
                         block,
@@ -10939,10 +11035,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     extra_index += 1;
                     const ranges_len = sema.code.extra[extra_index];
                     extra_index += 1;
-                    const body_len = @truncate(u31, sema.code.extra[extra_index]);
+                    const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                     extra_index += 1;
                     const items = sema.code.refSlice(extra_index, items_len);
-                    extra_index += items_len + body_len;
+                    extra_index += items_len + info.body_len;
 
                     try case_vals.ensureUnusedCapacity(gpa, items.len);
                     for (items, 0..) |item_ref, item_i| {
@@ -11006,7 +11102,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         .is_comptime = block.is_comptime,
         .comptime_reason = block.comptime_reason,
         .is_typeof = block.is_typeof,
-        .switch_else_err_ty = else_error_ty,
         .c_import_buf = block.c_import_buf,
         .runtime_cond = block.runtime_cond,
         .runtime_loop = block.runtime_loop,
@@ -11024,19 +11119,31 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             var scalar_i: usize = 0;
             while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                 extra_index += 1;
-                const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                const is_inline = sema.code.extra[extra_index] >> 31 != 0;
+                const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                 extra_index += 1;
-                const body = sema.code.extra[extra_index..][0..body_len];
-                extra_index += body_len;
+                const body = sema.code.extra[extra_index..][0..info.body_len];
+                extra_index += info.body_len;
 
                 const item = case_vals.items[scalar_i];
-                const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable;
-                if (resolved_operand_val.eql(item_val, operand_ty, mod)) {
-                    if (is_inline) child_block.inline_case_capture = operand;
-
+                const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable;
+                if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                    if (info.is_inline) child_block.inline_case_capture = operand;
                     if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                    return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
+                    return sema.resolveSwitchProngComptime(
+                        block,
+                        &child_block,
+                        src,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .normal,
+                        body,
+                        info.capture,
+                        .{ .scalar = @intCast(u32, scalar_i) },
+                        else_error_ty,
+                        &.{item},
+                        inst,
+                        merges,
+                    );
                 }
             }
         }
@@ -11048,22 +11155,34 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 extra_index += 1;
                 const ranges_len = sema.code.extra[extra_index];
                 extra_index += 1;
-                const body_len = @truncate(u31, sema.code.extra[extra_index]);
-                const is_inline = sema.code.extra[extra_index] >> 31 != 0;
+                const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
                 extra_index += 1 + items_len;
-                const body = sema.code.extra[extra_index + 2 * ranges_len ..][0..body_len];
+                const body = sema.code.extra[extra_index + 2 * ranges_len ..][0..info.body_len];
 
                 const items = case_vals.items[case_val_idx..][0..items_len];
                 case_val_idx += items_len;
 
                 for (items) |item| {
                     // Validation above ensured these will succeed.
-                    const item_val = sema.resolveConstLazyValue(&child_block, .unneeded, item, "") catch unreachable;
-                    if (resolved_operand_val.eql(item_val, operand_ty, mod)) {
-                        if (is_inline) child_block.inline_case_capture = operand;
-
+                    const item_val = sema.resolveConstValue(&child_block, .unneeded, item, "") catch unreachable;
+                    if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                        if (info.is_inline) child_block.inline_case_capture = operand;
                         if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                        return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
+                        return sema.resolveSwitchProngComptime(
+                            block,
+                            &child_block,
+                            src,
+                            raw_operand.val,
+                            raw_operand.ptr,
+                            .normal,
+                            body,
+                            info.capture,
+                            .{ .multi_capture = @intCast(u32, multi_i) },
+                            else_error_ty,
+                            items,
+                            inst,
+                            merges,
+                        );
                     }
                 }
 
@@ -11079,13 +11198,27 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and
                         (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty)))
                     {
-                        if (is_inline) child_block.inline_case_capture = operand;
+                        if (info.is_inline) child_block.inline_case_capture = operand;
                         if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                        return sema.resolveBlockBody(block, src, &child_block, body, inst, merges);
+                        return sema.resolveSwitchProngComptime(
+                            block,
+                            &child_block,
+                            src,
+                            raw_operand.val,
+                            raw_operand.ptr,
+                            .normal,
+                            body,
+                            info.capture,
+                            .{ .multi_capture = @intCast(u32, multi_i) },
+                            else_error_ty,
+                            undefined,
+                            inst,
+                            merges,
+                        );
                     }
                 }
 
-                extra_index += body_len;
+                extra_index += info.body_len;
             }
         }
         if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand);
@@ -11093,7 +11226,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (empty_enum) {
             return Air.Inst.Ref.void_value;
         }
-        return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
+
+        return sema.resolveSwitchProngComptime(
+            block,
+            &child_block,
+            src,
+            raw_operand.val,
+            raw_operand.ptr,
+            .special,
+            special.body,
+            special.capture,
+            .special,
+            else_error_ty,
+            undefined,
+            inst,
+            merges,
+        );
     }
 
     if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) {
@@ -11113,7 +11261,22 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             const ok = try block.addUnOp(.is_named_enum_value, operand);
             try sema.addSafetyCheck(block, ok, .corrupt_switch);
         }
-        return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
+
+        return sema.resolveSwitchProngComptime(
+            block,
+            &child_block,
+            src,
+            raw_operand.val,
+            raw_operand.ptr,
+            .special,
+            special.body,
+            special.capture,
+            .special,
+            else_error_ty,
+            undefined,
+            inst,
+            merges,
+        );
     }
 
     if (child_block.is_comptime) {
@@ -11140,11 +11303,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     var scalar_i: usize = 0;
     while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
         extra_index += 1;
-        const body_len = @truncate(u31, sema.code.extra[extra_index]);
-        const is_inline = sema.code.extra[extra_index] >> 31 != 0;
+        const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
         extra_index += 1;
-        const body = sema.code.extra[extra_index..][0..body_len];
-        extra_index += body_len;
+        const body = sema.code.extra[extra_index..][0..info.body_len];
+        extra_index += info.body_len;
 
         var wip_captures = try WipCaptureScope.init(gpa, child_block.wip_capture_scope);
         defer wip_captures.deinit();
@@ -11154,7 +11316,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         case_block.inline_case_capture = .none;
 
         const item = case_vals.items[scalar_i];
-        if (is_inline) case_block.inline_case_capture = item;
+        if (info.is_inline) case_block.inline_case_capture = item;
         // `item` is already guaranteed to be constant known.
 
         const analyze_body = if (union_originally) blk: {
@@ -11166,7 +11328,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
             // nothing to do here
         } else if (analyze_body) {
-            try sema.analyzeBodyRuntimeBreak(&case_block, body);
+            try sema.analyzeSwitchProngRuntime(
+                &case_block,
+                raw_operand.val,
+                raw_operand.ptr,
+                .normal,
+                body,
+                info.capture,
+                .{ .scalar = @intCast(u32, scalar_i) },
+                else_error_ty,
+                &.{item},
+                inst,
+            );
         } else {
             _ = try case_block.addNoOp(.unreach);
         }
@@ -11195,8 +11368,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         extra_index += 1;
         const ranges_len = sema.code.extra[extra_index];
         extra_index += 1;
-        const body_len = @truncate(u31, sema.code.extra[extra_index]);
-        const is_inline = sema.code.extra[extra_index] >> 31 != 0;
+        const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, sema.code.extra[extra_index]);
         extra_index += 1 + items_len;
 
         const items = case_vals.items[case_val_idx..][0..items_len];
@@ -11207,9 +11379,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         case_block.inline_case_capture = .none;
 
         // Generate all possible cases as scalar prongs.
-        if (is_inline) {
+        if (info.is_inline) {
             const body_start = extra_index + 2 * ranges_len;
-            const body = sema.code.extra[body_start..][0..body_len];
+            const body = sema.code.extra[body_start..][0..info.body_len];
             var emit_bb = false;
 
             var range_i: u32 = 0;
@@ -11250,7 +11422,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     };
                     emit_bb = true;
 
-                    try sema.analyzeBodyRuntimeBreak(&case_block, body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .normal,
+                        body,
+                        info.capture,
+                        .{ .multi_capture = multi_i },
+                        else_error_ty,
+                        undefined,
+                        inst,
+                    );
 
                     try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
                     cases_extra.appendAssumeCapacity(1); // items_len
@@ -11286,7 +11469,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 emit_bb = true;
 
                 if (analyze_body) {
-                    try sema.analyzeBodyRuntimeBreak(&case_block, body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .normal,
+                        body,
+                        info.capture,
+                        .{ .multi_capture = multi_i },
+                        else_error_ty,
+                        &.{item},
+                        inst,
+                    );
                 } else {
                     _ = try case_block.addNoOp(.unreach);
                 }
@@ -11298,7 +11492,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                 cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
             }
 
-            extra_index += body_len;
+            extra_index += info.body_len;
             continue;
         }
 
@@ -11319,12 +11513,23 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             else
                 true;
 
-            const body = sema.code.extra[extra_index..][0..body_len];
-            extra_index += body_len;
+            const body = sema.code.extra[extra_index..][0..info.body_len];
+            extra_index += info.body_len;
             if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
                 // nothing to do here
             } else if (analyze_body) {
-                try sema.analyzeBodyRuntimeBreak(&case_block, body);
+                try sema.analyzeSwitchProngRuntime(
+                    &case_block,
+                    raw_operand.val,
+                    raw_operand.ptr,
+                    .normal,
+                    body,
+                    info.capture,
+                    .{ .multi_capture = multi_i },
+                    else_error_ty,
+                    items,
+                    inst,
+                );
             } else {
                 _ = try case_block.addNoOp(.unreach);
             }
@@ -11397,12 +11602,23 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             case_block.instructions.shrinkRetainingCapacity(0);
             case_block.wip_capture_scope = wip_captures.scope;
 
-            const body = sema.code.extra[extra_index..][0..body_len];
-            extra_index += body_len;
+            const body = sema.code.extra[extra_index..][0..info.body_len];
+            extra_index += info.body_len;
             if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand)) {
                 // nothing to do here
             } else {
-                try sema.analyzeBodyRuntimeBreak(&case_block, body);
+                try sema.analyzeSwitchProngRuntime(
+                    &case_block,
+                    raw_operand.val,
+                    raw_operand.ptr,
+                    .normal,
+                    body,
+                    info.capture,
+                    .{ .multi_capture = multi_i },
+                    else_error_ty,
+                    items,
+                    inst,
+                );
             }
 
             try wip_captures.finalize();
@@ -11461,7 +11677,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     emit_bb = true;
 
                     if (analyze_body) {
-                        try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+                        try sema.analyzeSwitchProngRuntime(
+                            &case_block,
+                            raw_operand.val,
+                            raw_operand.ptr,
+                            .special,
+                            special.body,
+                            special.capture,
+                            .special,
+                            else_error_ty,
+                            &.{item_ref},
+                            inst,
+                        );
                     } else {
                         _ = try case_block.addNoOp(.unreach);
                     }
@@ -11497,7 +11724,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
                     emit_bb = true;
 
-                    try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .special,
+                        special.body,
+                        special.capture,
+                        .special,
+                        else_error_ty,
+                        &.{item_ref},
+                        inst,
+                    );
 
                     try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
                     cases_extra.appendAssumeCapacity(1); // items_len
@@ -11520,7 +11758,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
                     emit_bb = true;
 
-                    try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .special,
+                        special.body,
+                        special.capture,
+                        .special,
+                        else_error_ty,
+                        &.{item_ref},
+                        inst,
+                    );
 
                     try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
                     cases_extra.appendAssumeCapacity(1); // items_len
@@ -11540,7 +11789,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
                     emit_bb = true;
 
-                    try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .special,
+                        special.body,
+                        special.capture,
+                        .special,
+                        else_error_ty,
+                        &.{Air.Inst.Ref.bool_true},
+                        inst,
+                    );
 
                     try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
                     cases_extra.appendAssumeCapacity(1); // items_len
@@ -11558,7 +11818,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
                     emit_bb = true;
 
-                    try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+                    try sema.analyzeSwitchProngRuntime(
+                        &case_block,
+                        raw_operand.val,
+                        raw_operand.ptr,
+                        .special,
+                        special.body,
+                        special.capture,
+                        .special,
+                        else_error_ty,
+                        &.{Air.Inst.Ref.bool_false},
+                        inst,
+                    );
 
                     try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
                     cases_extra.appendAssumeCapacity(1); // items_len
@@ -11601,7 +11872,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         {
             // nothing to do here
         } else if (special.body.len != 0 and analyze_body and !special.is_inline) {
-            try sema.analyzeBodyRuntimeBreak(&case_block, special.body);
+            try sema.analyzeSwitchProngRuntime(
+                &case_block,
+                raw_operand.val,
+                raw_operand.ptr,
+                .special,
+                special.body,
+                special.capture,
+                .special,
+                else_error_ty,
+                undefined,
+                inst,
+            );
         } else {
             // We still need a terminator in this block, but we have proven
             // that it is unreachable.
@@ -11746,7 +12028,7 @@ fn resolveSwitchItemVal(
         else => |e| return e,
     };
 
-    const val = sema.resolveConstLazyValue(block, .unneeded, item, "") catch |err| switch (err) {
+    const maybe_lazy = sema.resolveConstValue(block, .unneeded, item, "") catch |err| switch (err) {
         error.NeededSourceLocation => {
             const src = switch_prong_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, range_expand);
             _ = try sema.resolveConstValue(block, src, item, "switch prong values must be comptime-known");
@@ -11755,7 +12037,12 @@ fn resolveSwitchItemVal(
         else => |e| return e,
     };
 
-    return .{ .ref = item, .val = val.toIntern() };
+    const val = try sema.resolveLazyValue(maybe_lazy);
+    const new_item = if (val.toIntern() != maybe_lazy.toIntern()) blk: {
+        break :blk try sema.addConstant(coerce_ty, val);
+    } else item;
+
+    return .{ .ref = new_item, .val = val.toIntern() };
 }
 
 fn validateSwitchRange(
src/Zir.zig
@@ -676,17 +676,6 @@ pub const Inst = struct {
         /// what will be switched on.
         /// Uses the `un_node` union field.
         switch_cond_ref,
-        /// Produces the capture value for a switch prong.
-        /// Uses the `switch_capture` field.
-        /// If the `prong_index` field is max int, it means this is the capture
-        /// for the else/`_` prong.
-        switch_capture,
-        /// Produces the capture value for a switch prong.
-        /// Result is a pointer to the value.
-        /// Uses the `switch_capture` field.
-        /// If the `prong_index` field is max int, it means this is the capture
-        /// for the else/`_` prong.
-        switch_capture_ref,
         /// Produces the capture value for an inline switch prong tag capture.
         /// Uses the `un_tok` field.
         switch_capture_tag,
@@ -1135,8 +1124,6 @@ pub const Inst = struct {
                 .typeof_log2_int_type,
                 .resolve_inferred_alloc,
                 .set_eval_branch_quota,
-                .switch_capture,
-                .switch_capture_ref,
                 .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
@@ -1427,8 +1414,6 @@ pub const Inst = struct {
                 .slice_length,
                 .import,
                 .typeof_log2_int_type,
-                .switch_capture,
-                .switch_capture_ref,
                 .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
@@ -1685,8 +1670,6 @@ pub const Inst = struct {
                 .switch_block = .pl_node,
                 .switch_cond = .un_node,
                 .switch_cond_ref = .un_node,
-                .switch_capture = .switch_capture,
-                .switch_capture_ref = .switch_capture,
                 .switch_capture_tag = .un_tok,
                 .array_base_ptr = .un_node,
                 .field_base_ptr = .un_node,
@@ -2254,10 +2237,6 @@ pub const Inst = struct {
             operand: Ref,
             payload_index: u32,
         },
-        switch_capture: struct {
-            switch_inst: Index,
-            prong_index: u32,
-        },
         dbg_stmt: LineColumn,
         /// Used for unary operators which reference an inst,
         /// with an AST node source location.
@@ -2327,7 +2306,6 @@ pub const Inst = struct {
             bool_br,
             @"unreachable",
             @"break",
-            switch_capture,
             dbg_stmt,
             inst_node,
             str_op,
@@ -2667,25 +2645,29 @@ pub const Inst = struct {
 
     /// 0. multi_cases_len: u32 // If has_multi_cases is set.
     /// 1. else_body { // If has_else or has_under is set.
-    ///        body_len: u32,
-    ///        body member Index for every body_len
+    ///        info: ProngInfo,
+    ///        body member Index for every info.body_len
     ///     }
     /// 2. scalar_cases: { // for every scalar_cases_len
     ///        item: Ref,
-    ///        body_len: u32,
-    ///        body member Index for every body_len
+    ///        info: ProngInfo,
+    ///        body member Index for every info.body_len
     ///     }
     /// 3. multi_cases: { // for every multi_cases_len
     ///        items_len: u32,
     ///        ranges_len: u32,
-    ///        body_len: u32,
+    ///        info: ProngInfo,
     ///        item: Ref // for every items_len
     ///        ranges: { // for every ranges_len
     ///            item_first: Ref,
     ///            item_last: Ref,
     ///        }
-    ///        body member Index for every body_len
+    ///        body member Index for every info.body_len
     ///    }
+    ///
+    /// When analyzing a case body, the switch instruction itself refers to the
+    /// captured payload. Whether this is captured by reference or by value
+    /// depends on whether the `byref` bit is set for the corresponding body.
     pub const SwitchBlock = struct {
         /// This is always a `switch_cond` or `switch_cond_ref` instruction.
         /// If it is a `switch_cond_ref` instruction, bits.is_ref is always true.
@@ -2697,6 +2679,19 @@ pub const Inst = struct {
         operand: Ref,
         bits: Bits,
 
+        /// These are stored in trailing data in `extra` for each prong.
+        pub const ProngInfo = packed struct(u32) {
+            body_len: u29,
+            capture: Capture,
+            is_inline: bool,
+
+            pub const Capture = enum(u2) {
+                none,
+                by_val,
+                by_ref,
+            };
+        };
+
         pub const Bits = packed struct {
             /// If true, one or more prongs have multiple items.
             has_multi_cases: bool,
@@ -2724,64 +2719,6 @@ pub const Inst = struct {
             items: []const Ref,
             body: []const Index,
         };
-
-        /// TODO performance optimization: instead of having this helper method
-        /// change the definition of switch_capture instruction to store extra_index
-        /// instead of prong_index. This way, Sema won't be doing O(N^2) iterations
-        /// over the switch prongs.
-        pub fn getProng(
-            self: SwitchBlock,
-            zir: Zir,
-            extra_end: usize,
-            prong_index: usize,
-        ) MultiProng {
-            var extra_index: usize = extra_end + @boolToInt(self.bits.has_multi_cases);
-
-            if (self.bits.specialProng() != .none) {
-                const body_len = @truncate(u31, zir.extra[extra_index]);
-                extra_index += 1;
-                const body = zir.extra[extra_index..][0..body_len];
-                extra_index += body.len;
-            }
-
-            var cur_idx: usize = 0;
-            while (cur_idx < self.bits.scalar_cases_len) : (cur_idx += 1) {
-                const items = zir.refSlice(extra_index, 1);
-                extra_index += 1;
-                const body_len = @truncate(u31, zir.extra[extra_index]);
-                extra_index += 1;
-                const body = zir.extra[extra_index..][0..body_len];
-                extra_index += body_len;
-                if (cur_idx == prong_index) {
-                    return .{
-                        .items = items,
-                        .body = body,
-                    };
-                }
-            }
-            while (true) : (cur_idx += 1) {
-                const items_len = zir.extra[extra_index];
-                extra_index += 1;
-                const ranges_len = zir.extra[extra_index];
-                extra_index += 1;
-                const body_len = @truncate(u31, zir.extra[extra_index]);
-                extra_index += 1;
-                const items = zir.refSlice(extra_index, items_len);
-                extra_index += items_len;
-                // Each range has a start and an end.
-                extra_index += 2 * ranges_len;
-
-                const body = zir.extra[extra_index..][0..body_len];
-                extra_index += body_len;
-
-                if (cur_idx == prong_index) {
-                    return .{
-                        .items = items,
-                        .body = body,
-                    };
-                }
-            }
-        }
     };
 
     pub const Field = struct {