Commit 4e32193de3

xtex <xtexchooser@duck.com>
2025-01-18 14:59:00
x86_64: implement integer saturating left shifting codegen
Similarly to shl_with_overflow, we first SHL/SAL the integer, then SHR/SAR it back and compare to detect whether overflow happened. If overflow happened, set the result to the type's limit to make it saturating. Bug: #17645 Co-authored-by: Jacob Young <jacobly0@users.noreply.github.com> Signed-off-by: Bingwu Zhang <xtex@aosc.io>
1 parent 6c3cbb0
Changed files (2)
src
arch
test
src/arch/x86_64/CodeGen.zig
@@ -85049,10 +85049,132 @@ fn airShlShrBinOp(self: *CodeGen, inst: Air.Inst.Index) !void {
 }
 
 fn airShlSat(self: *CodeGen, inst: Air.Inst.Index) !void {
+    const zcu = self.pt.zcu;
     const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
-    _ = bin_op;
-    return self.fail("TODO implement shl_sat for {}", .{self.target.cpu.arch});
-    //return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
+    const lhs_ty = self.typeOf(bin_op.lhs);
+    const rhs_ty = self.typeOf(bin_op.rhs);
+
+    const result: MCValue = result: {
+        switch (lhs_ty.zigTypeTag(zcu)) {
+            .int => {
+                const lhs_bits = lhs_ty.bitSize(zcu);
+                const rhs_bits = rhs_ty.bitSize(zcu);
+                if (!(lhs_bits <= 32 and rhs_bits <= 5) and !(lhs_bits > 32 and lhs_bits <= 64 and rhs_bits <= 6) and !(rhs_bits <= std.math.log2(lhs_bits))) {
+                    return self.fail("TODO implement shl_sat for {} with lhs bits {}, rhs bits {}", .{ self.target.cpu.arch, lhs_bits, rhs_bits });
+                }
+
+                // clobbered by genShiftBinOp
+                try self.spillRegisters(&.{.rcx});
+
+                const lhs_mcv = try self.resolveInst(bin_op.lhs);
+                var lhs_temp1 = try self.tempInit(lhs_ty, lhs_mcv);
+                const rhs_mcv = try self.resolveInst(bin_op.rhs);
+
+                const lhs_lock = switch (lhs_mcv) {
+                    .register => |reg| self.register_manager.lockRegAssumeUnused(reg),
+                    else => null,
+                };
+                defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
+
+                // shift left
+                const dst_mcv = try self.genShiftBinOp(.shl, null, lhs_mcv, rhs_mcv, lhs_ty, rhs_ty);
+                switch (dst_mcv) {
+                    .register => |dst_reg| try self.truncateRegister(lhs_ty, dst_reg),
+                    .register_pair => |dst_regs| try self.truncateRegister(lhs_ty, dst_regs[1]),
+                    .load_frame => |frame_addr| {
+                        const tmp_reg =
+                            try self.register_manager.allocReg(null, abi.RegisterClass.gp);
+                        const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg);
+                        defer self.register_manager.unlockReg(tmp_lock);
+
+                        const lhs_bits_u31: u31 = @intCast(lhs_bits);
+                        const tmp_ty: Type = if (lhs_bits_u31 > 64) .usize else lhs_ty;
+                        const off = frame_addr.off + (lhs_bits_u31 - 1) / 64 * 8;
+                        try self.genSetReg(
+                            tmp_reg,
+                            tmp_ty,
+                            .{ .load_frame = .{ .index = frame_addr.index, .off = off } },
+                            .{},
+                        );
+                        try self.truncateRegister(lhs_ty, tmp_reg);
+                        try self.genSetMem(
+                            .{ .frame = frame_addr.index },
+                            off,
+                            tmp_ty,
+                            .{ .register = tmp_reg },
+                            .{},
+                        );
+                    },
+                    else => {},
+                }
+                const dst_lock = switch (dst_mcv) {
+                    .register => |reg| self.register_manager.lockRegAssumeUnused(reg),
+                    else => null,
+                };
+                defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
+
+                // shift right
+                const tmp_mcv = try self.genShiftBinOp(.shr, null, dst_mcv, rhs_mcv, lhs_ty, rhs_ty);
+                var tmp_temp = try self.tempInit(lhs_ty, tmp_mcv);
+
+                // check if overflow happens
+                const cc_temp = lhs_temp1.cmpInts(.neq, &tmp_temp, self) catch |err| switch (err) {
+                    error.SelectFailed => unreachable,
+                    else => |e| return e,
+                };
+                try lhs_temp1.die(self);
+                try tmp_temp.die(self);
+                const overflow_reloc = try self.genCondBrMir(lhs_ty, cc_temp.tracking(self).short);
+                try cc_temp.die(self);
+
+                // if overflow,
+                // for unsigned integers, the saturating result is just its max
+                // for signed integers,
+                //   if lhs is positive, the result is its max
+                //   if lhs is negative, it is min
+                switch (lhs_ty.intInfo(zcu).signedness) {
+                    .unsigned => {
+                        const bound_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
+                        try self.genCopy(lhs_ty, dst_mcv, bound_mcv, .{});
+                    },
+                    .signed => {
+                        // check the sign of lhs
+                        // TODO: optimize this.
+                        // we only need the highest bit so shifting the highest part of lhs_mcv
+                        // is enough to check the signedness. other parts can be skipped here.
+                        var lhs_temp2 = try self.tempInit(lhs_ty, lhs_mcv);
+                        var zero_temp = try self.tempInit(lhs_ty, try self.genTypedValue(try self.pt.intValue(lhs_ty, 0)));
+                        const sign_cc_temp = lhs_temp2.cmpInts(.lt, &zero_temp, self) catch |err| switch (err) {
+                            error.SelectFailed => unreachable,
+                            else => |e| return e,
+                        };
+                        try lhs_temp2.die(self);
+                        try zero_temp.die(self);
+                        const sign_reloc_condbr = try self.genCondBrMir(lhs_ty, sign_cc_temp.tracking(self).short);
+                        try sign_cc_temp.die(self);
+
+                        // if it is negative
+                        const min_mcv = try self.genTypedValue(try lhs_ty.minIntScalar(self.pt, lhs_ty));
+                        try self.genCopy(lhs_ty, dst_mcv, min_mcv, .{});
+                        const sign_reloc_br = try self.asmJmpReloc(undefined);
+                        self.performReloc(sign_reloc_condbr);
+
+                        // if it is positive
+                        const max_mcv = try self.genTypedValue(try lhs_ty.maxIntScalar(self.pt, lhs_ty));
+                        try self.genCopy(lhs_ty, dst_mcv, max_mcv, .{});
+                        self.performReloc(sign_reloc_br);
+                    },
+                }
+
+                self.performReloc(overflow_reloc);
+                break :result dst_mcv;
+            },
+            else => {
+                return self.fail("TODO implement shl_sat for {} op type {}", .{ self.target.cpu.arch, lhs_ty.zigTypeTag(zcu) });
+            },
+        }
+    };
+    return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
 }
 
 fn airOptionalPayload(self: *CodeGen, inst: Air.Inst.Index) !void {
@@ -88437,7 +88559,7 @@ fn genShiftBinOpMir(
 ) !void {
     const pt = self.pt;
     const zcu = pt.zcu;
-    const abi_size: u32 = @intCast(lhs_ty.abiSize(zcu));
+    const abi_size: u31 = @intCast(lhs_ty.abiSize(zcu));
     const shift_abi_size: u32 = @intCast(rhs_ty.abiSize(zcu));
     try self.spillEflagsIfOccupied();
 
@@ -88621,7 +88743,17 @@ fn genShiftBinOpMir(
                 .immediate => {},
                 else => self.performReloc(skip),
             }
-        }
+        } else try self.asmRegisterMemory(.{ ._, .mov }, temp_regs[2].to64(), .{
+            .base = .{ .frame = lhs_mcv.load_frame.index },
+            .mod = .{ .rm = .{
+                .size = .qword,
+                .disp = switch (tag[0]) {
+                    ._l => lhs_mcv.load_frame.off,
+                    ._r => lhs_mcv.load_frame.off + abi_size - 8,
+                    else => unreachable,
+                },
+            } },
+        });
         switch (rhs_mcv) {
             .immediate => |shift_imm| try self.asmRegisterImmediate(
                 tag,
test/behavior/bit_shifting.zig
@@ -111,7 +111,6 @@ test "comptime shift safety check" {
 }
 
 test "Saturating Shift Left where lhs is of a computed type" {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO