Commit cff90e3ae0

Jacob Young <jacobly0@users.noreply.github.com>
2025-03-10 15:30:44
x86_64: implement select of register mask
1 parent 0ef3250
Changed files (2)
src
arch
test
behavior
src/arch/x86_64/CodeGen.zig
@@ -97980,16 +97980,150 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
             switch (pred_mcv) {
                 .register => |pred_reg| switch (pred_reg.class()) {
                     .general_purpose => {},
-                    .sse => if (need_xmm0 and pred_reg.id() != comptime Register.xmm0.id()) {
-                        try self.register_manager.getKnownReg(.xmm0, null);
-                        try self.genSetReg(.xmm0, pred_ty, pred_mcv, .{});
-                        break :mask .xmm0;
-                    } else break :mask if (has_blend)
-                        pred_reg
+                    .sse => if (elem_ty.toIntern() == .bool_type)
+                        if (need_xmm0 and pred_reg.id() != comptime Register.xmm0.id()) {
+                            try self.register_manager.getKnownReg(.xmm0, null);
+                            try self.genSetReg(.xmm0, pred_ty, pred_mcv, .{});
+                            break :mask .xmm0;
+                        } else break :mask if (has_blend)
+                            pred_reg
+                        else
+                            try self.copyToTmpRegister(pred_ty, pred_mcv)
                     else
-                        try self.copyToTmpRegister(pred_ty, pred_mcv),
+                        return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)}),
                     else => unreachable,
                 },
+                .register_mask => |pred_reg_mask| {
+                    if (pred_reg_mask.info.scalar.bitSize(self.target) != 8 * elem_abi_size)
+                        return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
+
+                    const mask_reg: Register = if (need_xmm0 and pred_reg_mask.reg.id() != comptime Register.xmm0.id()) mask_reg: {
+                        try self.register_manager.getKnownReg(.xmm0, null);
+                        try self.genSetReg(.xmm0, ty, .{ .register = pred_reg_mask.reg }, .{});
+                        break :mask_reg .xmm0;
+                    } else pred_reg_mask.reg;
+                    const mask_alias = registerAlias(mask_reg, abi_size);
+                    const mask_lock = self.register_manager.lockRegAssumeUnused(mask_reg);
+                    defer self.register_manager.unlockReg(mask_lock);
+
+                    const lhs_mcv = try self.resolveInst(extra.lhs);
+                    const lhs_lock = switch (lhs_mcv) {
+                        .register => |lhs_reg| self.register_manager.lockRegAssumeUnused(lhs_reg),
+                        else => null,
+                    };
+                    defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+                    const rhs_mcv = try self.resolveInst(extra.rhs);
+                    const rhs_lock = switch (rhs_mcv) {
+                        .register => |rhs_reg| self.register_manager.lockReg(rhs_reg),
+                        else => null,
+                    };
+                    defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+                    const order = has_blend != pred_reg_mask.info.inverted;
+                    const reuse_mcv, const other_mcv = if (order)
+                        .{ rhs_mcv, lhs_mcv }
+                    else
+                        .{ lhs_mcv, rhs_mcv };
+                    const dst_mcv: MCValue = if (reuse_mcv.isRegister() and self.reuseOperand(
+                        inst,
+                        if (order) extra.rhs else extra.lhs,
+                        @intFromBool(order),
+                        reuse_mcv,
+                    )) reuse_mcv else if (has_avx)
+                        .{ .register = try self.register_manager.allocReg(inst, abi.RegisterClass.sse) }
+                    else
+                        try self.copyToRegisterWithInstTracking(inst, ty, reuse_mcv);
+                    const dst_reg = dst_mcv.getReg().?;
+                    const dst_alias = registerAlias(dst_reg, abi_size);
+                    const dst_lock = self.register_manager.lockReg(dst_reg);
+                    defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
+
+                    const mir_tag = @as(?Mir.Inst.FixedTag, if ((pred_reg_mask.info.kind == .all and
+                        elem_ty.toIntern() != .f32_type and elem_ty.toIntern() != .f64_type) or pred_reg_mask.info.scalar == .byte)
+                        if (has_avx)
+                            .{ .vp_b, .blendv }
+                        else if (has_blend)
+                            .{ .p_b, .blendv }
+                        else if (pred_reg_mask.info.kind == .all)
+                            .{ .p_, undefined }
+                        else
+                            null
+                    else if ((pred_reg_mask.info.kind == .all and (elem_ty.toIntern() != .f64_type or !self.hasFeature(.sse2))) or
+                        pred_reg_mask.info.scalar == .dword)
+                        if (has_avx)
+                            .{ .v_ps, .blendv }
+                        else if (has_blend)
+                            .{ ._ps, .blendv }
+                        else if (pred_reg_mask.info.kind == .all)
+                            .{ ._ps, undefined }
+                        else
+                            null
+                    else if (pred_reg_mask.info.kind == .all or pred_reg_mask.info.scalar == .qword)
+                        if (has_avx)
+                            .{ .v_pd, .blendv }
+                        else if (has_blend)
+                            .{ ._pd, .blendv }
+                        else if (pred_reg_mask.info.kind == .all)
+                            .{ ._pd, undefined }
+                        else
+                            null
+                    else
+                        null) orelse return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
+                    if (has_avx) {
+                        const rhs_alias = if (reuse_mcv.isRegister())
+                            registerAlias(reuse_mcv.getReg().?, abi_size)
+                        else rhs: {
+                            try self.genSetReg(dst_reg, ty, reuse_mcv, .{});
+                            break :rhs dst_alias;
+                        };
+                        if (other_mcv.isBase()) try self.asmRegisterRegisterMemoryRegister(
+                            mir_tag,
+                            dst_alias,
+                            rhs_alias,
+                            try other_mcv.mem(self, .{ .size = self.memSize(ty) }),
+                            mask_alias,
+                        ) else try self.asmRegisterRegisterRegisterRegister(
+                            mir_tag,
+                            dst_alias,
+                            rhs_alias,
+                            registerAlias(if (other_mcv.isRegister())
+                                other_mcv.getReg().?
+                            else
+                                try self.copyToTmpRegister(ty, other_mcv), abi_size),
+                            mask_alias,
+                        );
+                    } else if (has_blend) if (other_mcv.isBase()) try self.asmRegisterMemoryRegister(
+                        mir_tag,
+                        dst_alias,
+                        try other_mcv.mem(self, .{ .size = self.memSize(ty) }),
+                        mask_alias,
+                    ) else try self.asmRegisterRegisterRegister(
+                        mir_tag,
+                        dst_alias,
+                        registerAlias(if (other_mcv.isRegister())
+                            other_mcv.getReg().?
+                        else
+                            try self.copyToTmpRegister(ty, other_mcv), abi_size),
+                        mask_alias,
+                    ) else {
+                        try self.asmRegisterRegister(.{ mir_tag[0], .@"and" }, dst_alias, mask_alias);
+                        if (other_mcv.isBase()) try self.asmRegisterMemory(
+                            .{ mir_tag[0], .andn },
+                            mask_alias,
+                            try other_mcv.mem(self, .{ .size = .fromSize(abi_size) }),
+                        ) else try self.asmRegisterRegister(
+                            .{ mir_tag[0], .andn },
+                            mask_alias,
+                            if (other_mcv.isRegister())
+                                other_mcv.getReg().?
+                            else
+                                try self.copyToTmpRegister(ty, other_mcv),
+                        );
+                        try self.asmRegisterRegister(.{ mir_tag[0], .@"or" }, dst_alias, mask_alias);
+                    }
+                    break :result dst_mcv;
+                },
                 else => {},
             }
             const mask_reg: Register = if (need_xmm0) mask_reg: {
@@ -98192,7 +98326,7 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
         const dst_lock = self.register_manager.lockReg(dst_reg);
         defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
 
-        const mir_tag = @as(?Mir.Inst.FixedTag, switch (ty.childType(zcu).zigTypeTag(zcu)) {
+        const mir_tag = @as(?Mir.Inst.FixedTag, switch (elem_ty.zigTypeTag(zcu)) {
             else => null,
             .int => switch (abi_size) {
                 0 => unreachable,
@@ -98208,7 +98342,7 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
                     null,
                 else => null,
             },
-            .float => switch (ty.childType(zcu).floatBits(self.target.*)) {
+            .float => switch (elem_ty.floatBits(self.target.*)) {
                 else => unreachable,
                 16, 80, 128 => null,
                 32 => switch (vec_len) {
@@ -98262,30 +98396,20 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
                 try self.copyToTmpRegister(ty, lhs_mcv), abi_size),
             mask_alias,
         ) else {
-            const mir_fixes = @as(?Mir.Inst.Fixes, switch (elem_ty.zigTypeTag(zcu)) {
-                else => null,
-                .int => .p_,
-                .float => switch (elem_ty.floatBits(self.target.*)) {
-                    32 => ._ps,
-                    64 => ._pd,
-                    16, 80, 128 => null,
-                    else => unreachable,
-                },
-            }) orelse return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
-            try self.asmRegisterRegister(.{ mir_fixes, .@"and" }, dst_alias, mask_alias);
+            try self.asmRegisterRegister(.{ mir_tag[0], .@"and" }, dst_alias, mask_alias);
             if (rhs_mcv.isBase()) try self.asmRegisterMemory(
-                .{ mir_fixes, .andn },
+                .{ mir_tag[0], .andn },
                 mask_alias,
                 try rhs_mcv.mem(self, .{ .size = .fromSize(abi_size) }),
             ) else try self.asmRegisterRegister(
-                .{ mir_fixes, .andn },
+                .{ mir_tag[0], .andn },
                 mask_alias,
                 if (rhs_mcv.isRegister())
                     rhs_mcv.getReg().?
                 else
                     try self.copyToTmpRegister(ty, rhs_mcv),
             );
-            try self.asmRegisterRegister(.{ mir_fixes, .@"or" }, dst_alias, mask_alias);
+            try self.asmRegisterRegister(.{ mir_tag[0], .@"or" }, dst_alias, mask_alias);
         }
         break :result dst_mcv;
     };
test/behavior/select.zig
@@ -66,3 +66,23 @@ fn selectArrays() !void {
     const xyz = @select(f32, x, y, z);
     try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
 }
+
+test "@select compare result" {
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
+
+    const S = struct {
+        fn min(comptime V: type, lhs: V, rhs: V) V {
+            return @select(@typeInfo(V).vector.child, lhs < rhs, lhs, rhs);
+        }
+
+        fn doTheTest() !void {
+            try expect(@reduce(.And, min(@Vector(4, f32), .{ -1, 2, -3, 4 }, .{ 1, -2, 3, -4 }) == @Vector(4, f32){ -1, -2, -3, -4 }));
+            try expect(@reduce(.And, min(@Vector(2, f64), .{ -1, 2 }, .{ 1, -2 }) == @Vector(2, f64){ -1, -2 }));
+        }
+    };
+
+    try S.doTheTest();
+    try comptime S.doTheTest();
+}