Commit a377bf87ce

mlugg <mlugg@mlugg.co.uk>
2023-05-05 22:40:04
Zir: remove unnecessary switch_capture_multi instructions
By indexing from the very first switch case rather than into scalar and multi cases separately, the instructions for capturing in multi cases become unnecessary, freeing up 2 ZIR tags.
1 parent 387f956
src/AstGen.zig
@@ -2614,8 +2614,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .switch_cond_ref,
             .switch_capture,
             .switch_capture_ref,
-            .switch_capture_multi,
-            .switch_capture_multi_ref,
             .switch_capture_tag,
             .struct_init_empty,
             .struct_init,
@@ -6916,15 +6914,8 @@ fn switchExpr(
                         },
                     });
                 } else {
-                    const is_multi_case_bits: u2 = @boolToInt(is_multi_case);
-                    const is_ptr_bits: u2 = @boolToInt(is_ptr);
-                    const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) {
-                        0b00 => .switch_capture,
-                        0b01 => .switch_capture_ref,
-                        0b10 => .switch_capture_multi,
-                        0b11 => .switch_capture_multi_ref,
-                    };
-                    const capture_index = if (is_multi_case) multi_case_index else scalar_case_index;
+                    const capture_tag: Zir.Inst.Tag = if (is_ptr) .switch_capture_ref else .switch_capture;
+                    const capture_index = if (is_multi_case) scalar_cases_len + multi_case_index else scalar_case_index;
                     capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
                     try astgen.instructions.append(gpa, .{
                         .tag = capture_tag,
src/print_zir.zig
@@ -438,8 +438,6 @@ const Writer = struct {
 
             .switch_capture,
             .switch_capture_ref,
-            .switch_capture_multi,
-            .switch_capture_multi_ref,
             => try self.writeSwitchCapture(stream, inst),
 
             .dbg_stmt => try self.writeDbgStmt(stream, inst),
src/Sema.zig
@@ -1017,10 +1017,8 @@ fn analyzeBodyInner(
             .switch_block                 => try sema.zirSwitchBlock(block, inst),
             .switch_cond                  => try sema.zirSwitchCond(block, inst, false),
             .switch_cond_ref              => try sema.zirSwitchCond(block, inst, true),
-            .switch_capture               => try sema.zirSwitchCapture(block, inst, false, false),
-            .switch_capture_ref           => try sema.zirSwitchCapture(block, inst, false, true),
-            .switch_capture_multi         => try sema.zirSwitchCapture(block, inst, true, false),
-            .switch_capture_multi_ref     => try sema.zirSwitchCapture(block, inst, true, true),
+            .switch_capture               => try sema.zirSwitchCapture(block, inst, false),
+            .switch_capture_ref           => try sema.zirSwitchCapture(block, inst, true),
             .switch_capture_tag           => try sema.zirSwitchCaptureTag(block, inst),
             .type_info                    => try sema.zirTypeInfo(block, inst),
             .size_of                      => try sema.zirSizeOf(block, inst),
@@ -10089,7 +10087,6 @@ fn zirSwitchCapture(
     sema: *Sema,
     block: *Block,
     inst: Zir.Inst.Index,
-    is_multi: bool,
     is_ref: bool,
 ) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
@@ -10178,12 +10175,7 @@ fn zirSwitchCapture(
         }
     }
 
-    const items = if (is_multi)
-        switch_extra.data.getMultiProng(sema.code, switch_extra.end, capture_info.prong_index).items
-    else
-        &[_]Zir.Inst.Ref{
-            switch_extra.data.getScalarProng(sema.code, switch_extra.end, capture_info.prong_index).item,
-        };
+    const items = switch_extra.data.getProng(sema.code, switch_extra.end, capture_info.prong_index).items;
 
     switch (operand_ty.zigTypeTag(mod)) {
         .Union => {
@@ -10252,7 +10244,7 @@ fn zirSwitchCapture(
             return block.addStructFieldVal(operand, first_field_index, first_field.ty);
         },
         .ErrorSet => {
-            if (is_multi) {
+            if (items.len > 1) {
                 var names: Module.Fn.InferredErrorSet.NameMap = .{};
                 try names.ensureUnusedCapacity(sema.arena, items.len);
                 for (items) |item| {
src/Zir.zig
@@ -687,15 +687,6 @@ pub const Inst = struct {
         /// If the `prong_index` field is max int, it means this is the capture
         /// for the else/`_` prong.
         switch_capture_ref,
-        /// Produces the capture value for a switch prong.
-        /// The prong is one of the multi cases.
-        /// Uses the `switch_capture` field.
-        switch_capture_multi,
-        /// Produces the capture value for a switch prong.
-        /// The prong is one of the multi cases.
-        /// Result is a pointer to the value.
-        /// Uses the `switch_capture` field.
-        switch_capture_multi_ref,
         /// Produces the capture value for an inline switch prong tag capture.
         /// Uses the `un_tok` field.
         switch_capture_tag,
@@ -1146,8 +1137,6 @@ pub const Inst = struct {
                 .set_eval_branch_quota,
                 .switch_capture,
                 .switch_capture_ref,
-                .switch_capture_multi,
-                .switch_capture_multi_ref,
                 .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
@@ -1440,8 +1429,6 @@ pub const Inst = struct {
                 .typeof_log2_int_type,
                 .switch_capture,
                 .switch_capture_ref,
-                .switch_capture_multi,
-                .switch_capture_multi_ref,
                 .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
@@ -1700,8 +1687,6 @@ pub const Inst = struct {
                 .switch_cond_ref = .un_node,
                 .switch_capture = .switch_capture,
                 .switch_capture_ref = .switch_capture,
-                .switch_capture_multi = .switch_capture,
-                .switch_capture_multi_ref = .switch_capture,
                 .switch_capture_tag = .un_tok,
                 .array_base_ptr = .un_node,
                 .field_base_ptr = .un_node,
@@ -2735,8 +2720,8 @@ pub const Inst = struct {
             }
         };
 
-        pub const ScalarProng = struct {
-            item: Ref,
+        pub const MultiProng = struct {
+            items: []const Ref,
             body: []const Index,
         };
 
@@ -2744,56 +2729,13 @@ pub const Inst = struct {
         /// change the definition of switch_capture instruction to store extra_index
         /// instead of prong_index. This way, Sema won't be doing O(N^2) iterations
         /// over the switch prongs.
-        pub fn getScalarProng(
-            self: SwitchBlock,
-            zir: Zir,
-            extra_end: usize,
-            prong_index: usize,
-        ) ScalarProng {
-            var extra_index: usize = extra_end;
-
-            if (self.bits.has_multi_cases) {
-                extra_index += 1;
-            }
-
-            if (self.bits.specialProng() != .none) {
-                const body_len = @truncate(u31, zir.extra[extra_index]);
-                extra_index += 1;
-                const body = zir.extra[extra_index..][0..body_len];
-                extra_index += body.len;
-            }
-
-            var scalar_i: usize = 0;
-            while (true) : (scalar_i += 1) {
-                const item = @intToEnum(Ref, zir.extra[extra_index]);
-                extra_index += 1;
-                const body_len = @truncate(u31, zir.extra[extra_index]);
-                extra_index += 1;
-                const body = zir.extra[extra_index..][0..body_len];
-                extra_index += body.len;
-
-                if (scalar_i < prong_index) continue;
-
-                return .{
-                    .item = item,
-                    .body = body,
-                };
-            }
-        }
-
-        pub const MultiProng = struct {
-            items: []const Ref,
-            body: []const Index,
-        };
-
-        pub fn getMultiProng(
+        pub fn getProng(
             self: SwitchBlock,
             zir: Zir,
             extra_end: usize,
             prong_index: usize,
         ) MultiProng {
-            // +1 for self.bits.has_multi_cases == true
-            var extra_index: usize = extra_end + 1;
+            var extra_index: usize = extra_end + @boolToInt(self.bits.has_multi_cases);
 
             if (self.bits.specialProng() != .none) {
                 const body_len = @truncate(u31, zir.extra[extra_index]);
@@ -2802,15 +2744,22 @@ pub const Inst = struct {
                 extra_index += body.len;
             }
 
-            var scalar_i: usize = 0;
-            while (scalar_i < self.bits.scalar_cases_len) : (scalar_i += 1) {
+            var cur_idx: usize = 0;
+            while (cur_idx < self.bits.scalar_cases_len) : (cur_idx += 1) {
+                const items = zir.refSlice(extra_index, 1);
                 extra_index += 1;
                 const body_len = @truncate(u31, zir.extra[extra_index]);
                 extra_index += 1;
+                const body = zir.extra[extra_index..][0..body_len];
                 extra_index += body_len;
+                if (cur_idx == prong_index) {
+                    return .{
+                        .items = items,
+                        .body = body,
+                    };
+                }
             }
-            var multi_i: u32 = 0;
-            while (true) : (multi_i += 1) {
+            while (true) : (cur_idx += 1) {
                 const items_len = zir.extra[extra_index];
                 extra_index += 1;
                 const ranges_len = zir.extra[extra_index];
@@ -2825,11 +2774,12 @@ pub const Inst = struct {
                 const body = zir.extra[extra_index..][0..body_len];
                 extra_index += body_len;
 
-                if (multi_i < prong_index) continue;
-                return .{
-                    .items = items,
-                    .body = body,
-                };
+                if (cur_idx == prong_index) {
+                    return .{
+                        .items = items,
+                        .body = body,
+                    };
+                }
             }
         }
     };