Commit b48d6ff619

Jacob Young <jacobly0@users.noreply.github.com>
2025-05-31 00:04:30
Legalize: implement scalarization of `@select`
1 parent 32a57bf
Changed files (5)
lib
src
test
behavior
lib/std/simd.zig
@@ -368,9 +368,6 @@ pub fn countElementsWithValue(vec: anytype, value: std.meta.Child(@TypeOf(vec)))
 }
 
 test "vector searching" {
-    if (builtin.zig_backend == .stage2_x86_64 and
-        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) return error.SkipZigTest;
-
     const base = @Vector(8, u32){ 6, 4, 7, 4, 4, 2, 3, 7 };
 
     try std.testing.expectEqual(@as(?u3, 1), firstIndexOfValue(base, 4));
src/Air/Legalize.zig
@@ -74,6 +74,7 @@ pub const Feature = enum {
     scalarize_int_from_float,
     scalarize_int_from_float_optimized,
     scalarize_float_from_int,
+    scalarize_select,
     scalarize_mul_add,
 
     /// Legalize (shift lhs, (splat rhs)) -> (shift lhs, rhs)
@@ -167,6 +168,7 @@ pub const Feature = enum {
             .int_from_float => .scalarize_int_from_float,
             .int_from_float_optimized => .scalarize_int_from_float_optimized,
             .float_from_int => .scalarize_float_from_int,
+            .select => .scalarize_select,
             .mul_add => .scalarize_mul_add,
         };
     }
@@ -520,7 +522,9 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
             },
             .splat,
             .shuffle,
+            => {},
             .select,
+            => if (l.features.contains(.scalarize_select)) continue :inst try l.scalarize(inst, .select_pl_op_bin),
             .memset,
             .memset_safe,
             .memcpy,
@@ -568,7 +572,7 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
     }
 }
 
-const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin };
+const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin, select_pl_op_bin };
 inline fn scalarize(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_tag: ScalarizeDataTag) Error!Air.Inst.Tag {
     return l.replaceInst(orig_inst, .block, try l.scalarizeBlockPayload(orig_inst, data_tag));
 }
@@ -584,6 +588,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
             .un_op, .ty_op => 1,
             .bin_op, .ty_pl_vector_cmp => 2,
             .pl_op_bin => 3,
+            .select_pl_op_bin => 6,
         } + 9
     ]Air.Inst.Index = undefined;
     try l.air_instructions.ensureUnusedCapacity(zcu.gpa, inst_buf.len);
@@ -722,23 +727,67 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
                                     } },
                                 });
                             },
+                            .select_pl_op_bin => {
+                                const extra = l.extraData(Air.Bin, orig.data.pl_op.payload).data;
+                                var res_elem: Result = .init(l, l.typeOf(extra.lhs).scalarType(zcu), &loop.block);
+                                res_elem.block = .init(loop.block.stealCapacity(6));
+                                {
+                                    var select_cond_br: CondBr = .init(l, res_elem.block.add(l, .{
+                                        .tag = .array_elem_val,
+                                        .data = .{ .bin_op = .{
+                                            .lhs = orig.data.pl_op.operand,
+                                            .rhs = cur_index_inst.toRef(),
+                                        } },
+                                    }).toRef(), &res_elem.block, .{});
+                                    select_cond_br.then_block = .init(res_elem.block.stealRemainingCapacity());
+                                    {
+                                        _ = select_cond_br.then_block.add(l, .{
+                                            .tag = .br,
+                                            .data = .{ .br = .{
+                                                .block_inst = res_elem.inst,
+                                                .operand = select_cond_br.then_block.add(l, .{
+                                                    .tag = .array_elem_val,
+                                                    .data = .{ .bin_op = .{
+                                                        .lhs = extra.lhs,
+                                                        .rhs = cur_index_inst.toRef(),
+                                                    } },
+                                                }).toRef(),
+                                            } },
+                                        });
+                                    }
+                                    select_cond_br.else_block = .init(select_cond_br.then_block.stealRemainingCapacity());
+                                    {
+                                        _ = select_cond_br.else_block.add(l, .{
+                                            .tag = .br,
+                                            .data = .{ .br = .{
+                                                .block_inst = res_elem.inst,
+                                                .operand = select_cond_br.else_block.add(l, .{
+                                                    .tag = .array_elem_val,
+                                                    .data = .{ .bin_op = .{
+                                                        .lhs = extra.rhs,
+                                                        .rhs = cur_index_inst.toRef(),
+                                                    } },
+                                                }).toRef(),
+                                            } },
+                                        });
+                                    }
+                                    try select_cond_br.finish(l);
+                                }
+                                try res_elem.finish(l);
+                                break :res_elem res_elem.inst;
+                            },
                         }.toRef(),
                     }),
                 } },
             });
 
-            var loop_cond_br: CondBr = .init(
+            var loop_cond_br: CondBr = .init(l, (try loop.block.addCmp(
                 l,
-                (try loop.block.addCmp(
-                    l,
-                    .lt,
-                    cur_index_inst.toRef(),
-                    try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1),
-                    .{},
-                )).toRef(),
-                &loop.block,
+                .lt,
+                cur_index_inst.toRef(),
+                try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1),
                 .{},
-            );
+            )).toRef(), &loop.block, .{});
             loop_cond_br.then_block = .init(loop.block.stealRemainingCapacity());
             {
                 _ = loop_cond_br.then_block.add(l, .{
@@ -1138,9 +1187,21 @@ const Block = struct {
     /// This is useful when you've provided a buffer big enough for all your instructions, but you are
     /// now starting a new block and some of them need to live there instead.
     fn stealRemainingCapacity(b: *Block) []Air.Inst.Index {
-        const remaining = b.instructions[b.len..];
-        b.instructions = b.instructions[0..b.len];
-        return remaining;
+        return b.stealFrom(b.len);
+    }
+
+    /// Returns `len` elements taken from the unused capacity of `b.instructions`, and shrinks
+    /// `b.instructions` down to not include them anymore.
+    /// This is useful when you've provided a buffer big enough for all your instructions, but you are
+    /// now starting a new block and some of them need to live there instead.
+    fn stealCapacity(b: *Block, len: usize) []Air.Inst.Index {
+        return b.stealFrom(b.instructions.len - len);
+    }
+
+    fn stealFrom(b: *Block, start: usize) []Air.Inst.Index {
+        assert(start >= b.len);
+        defer b.instructions.len = start;
+        return b.instructions[start..];
     }
 
     fn body(b: *const Block) []const Air.Inst.Index {
@@ -1149,6 +1210,31 @@ const Block = struct {
     }
 };
 
+const Result = struct {
+    inst: Air.Inst.Index,
+    block: Block,
+
+    /// The return value has `block` initialized to `undefined`; it is the caller's reponsibility
+    /// to initialize it.
+    fn init(l: *Legalize, ty: Type, parent_block: *Block) Result {
+        return .{
+            .inst = parent_block.add(l, .{
+                .tag = .block,
+                .data = .{ .ty_pl = .{
+                    .ty = Air.internedToRef(ty.toIntern()),
+                    .payload = undefined,
+                } },
+            }),
+            .block = undefined,
+        };
+    }
+
+    fn finish(res: Result, l: *Legalize) Error!void {
+        const data = &l.air_instructions.items(.data)[@intFromEnum(res.inst)];
+        data.ty_pl.payload = try l.addBlockBody(res.block.body());
+    }
+};
+
 const Loop = struct {
     inst: Air.Inst.Index,
     block: Block,
src/arch/x86_64/CodeGen.zig
@@ -84,6 +84,7 @@ pub fn legalizeFeatures(target: *const std.Target) *const Air.Legalize.Features
             .scalarize_int_from_float = use_old,
             .scalarize_int_from_float_optimized = use_old,
             .scalarize_float_from_int = use_old,
+            .scalarize_select = true,
             .scalarize_mul_add = use_old,
 
             .unsplat_shift_rhs = false,
src/Compilation.zig
@@ -2529,6 +2529,7 @@ pub fn destroy(comp: *Compilation) void {
 
 pub fn clearMiscFailures(comp: *Compilation) void {
     comp.alloc_failure_occurred = false;
+    comp.link_diags.flags = .{};
     for (comp.misc_failures.values()) |*value| {
         value.deinit(comp.gpa);
     }
@@ -2795,7 +2796,6 @@ pub fn update(comp: *Compilation, main_progress_node: std.Progress.Node) !void {
 
     if (anyErrors(comp)) {
         // Skip flushing and keep source files loaded for error reporting.
-        comp.link_diags.flags = .{};
         return;
     }
 
test/behavior/select.zig
@@ -41,8 +41,6 @@ test "@select arrays" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_x86_64 and
-        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) return error.SkipZigTest;
 
     try comptime selectArrays();
     try selectArrays();
@@ -70,7 +68,6 @@ fn selectArrays() !void {
 test "@select compare result" {
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .hexagon) return error.SkipZigTest;
 
     const S = struct {