Commit 46d3e5bb04

Jacob Young <jacobly0@users.noreply.github.com>
2023-11-03 12:55:09
x86_64: reduce `RegisterManager` performance regression
This reduces the regression from 0.11.0 by 95%. Closes #17678
1 parent 509be7c
Changed files (3)
src/arch/x86_64/bits.zig
@@ -222,7 +222,7 @@ pub const Register = enum(u7) {
             @intFromEnum(Register.eax)  ... @intFromEnum(Register.r15d)  => @intFromEnum(Register.eax),
             @intFromEnum(Register.ax)   ... @intFromEnum(Register.r15w)  => @intFromEnum(Register.ax),
             @intFromEnum(Register.al)   ... @intFromEnum(Register.r15b)  => @intFromEnum(Register.al),
-            @intFromEnum(Register.ah)   ... @intFromEnum(Register.bh)    => @intFromEnum(Register.ah) - 4,
+            @intFromEnum(Register.ah)   ... @intFromEnum(Register.bh)    => @intFromEnum(Register.ah),
 
             @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm15) => @intFromEnum(Register.ymm0) - 16,
             @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm15) => @intFromEnum(Register.xmm0) - 16,
src/arch/x86_64/CodeGen.zig
@@ -2608,8 +2608,17 @@ pub fn spillEflagsIfOccupied(self: *Self) !void {
     }
 }
 
-pub fn spillRegisters(self: *Self, registers: []const Register) !void {
-    for (registers) |reg| try self.register_manager.getReg(reg, null);
+pub fn spillCallerPreservedRegs(self: *Self, cc: std.builtin.CallingConvention) !void {
+    switch (cc) {
+        inline .SysV, .Win64 => |known_cc| try self.spillRegisters(
+            comptime abi.getCallerPreservedRegs(known_cc),
+        ),
+        else => unreachable,
+    }
+}
+
+pub fn spillRegisters(self: *Self, comptime registers: []const Register) !void {
+    inline for (registers) |reg| try self.register_manager.getKnownReg(reg, null);
 }
 
 /// Copies a value to a register without tracking the register. The register is not considered
@@ -10639,30 +10648,30 @@ fn genCall(self: *Self, info: union(enum) {
     }
 
     try self.spillEflagsIfOccupied();
-    try self.spillRegisters(abi.getCallerPreservedRegs(resolved_cc));
+    try self.spillCallerPreservedRegs(resolved_cc);
 
     // set stack arguments first because this can clobber registers
     // also clobber spill arguments as we go
     switch (call_info.return_value.long) {
         .none, .unreach => {},
-        .indirect => |reg_off| try self.spillRegisters(&.{reg_off.reg}),
+        .indirect => |reg_off| try self.register_manager.getReg(reg_off.reg, null),
         else => unreachable,
     }
     for (call_info.args, arg_types, args, frame_indices) |dst_arg, arg_ty, src_arg, *frame_index|
         switch (dst_arg) {
             .none => {},
             .register => |reg| {
-                try self.spillRegisters(&.{reg});
+                try self.register_manager.getReg(reg, null);
                 try reg_locks.append(self.register_manager.lockReg(reg));
             },
             .register_pair => |regs| {
-                try self.spillRegisters(&regs);
+                for (regs) |reg| try self.register_manager.getReg(reg, null);
                 try reg_locks.appendSlice(&self.register_manager.lockRegs(2, regs));
             },
             .indirect => |reg_off| {
                 frame_index.* = try self.allocFrameIndex(FrameAlloc.initType(arg_ty, mod));
                 try self.genSetMem(.{ .frame = frame_index.* }, 0, arg_ty, src_arg);
-                try self.spillRegisters(&.{reg_off.reg});
+                try self.register_manager.getReg(reg_off.reg, null);
                 try reg_locks.append(self.register_manager.lockReg(reg_off.reg));
             },
             .load_frame => {
@@ -11990,8 +11999,9 @@ fn airAsm(self: *Self, inst: Air.Inst.Index) !void {
     var args = std.ArrayList(MCValue).init(self.gpa);
     try args.ensureTotalCapacity(outputs.len + inputs.len);
     defer {
-        for (args.items) |arg| if (arg.getReg()) |reg|
-            self.register_manager.unlockReg(.{ .register = reg });
+        for (args.items) |arg| if (arg.getReg()) |reg| self.register_manager.unlockReg(.{
+            .tracked_index = RegisterManager.indexOfRegIntoTracked(reg) orelse continue,
+        });
         args.deinit();
     }
     var arg_map = std.StringHashMap(u8).init(self.gpa);
@@ -14557,7 +14567,7 @@ fn airTagName(self: *Self, inst: Air.Inst.Index) !void {
     }
 
     try self.spillEflagsIfOccupied();
-    try self.spillRegisters(abi.getCallerPreservedRegs(resolved_cc));
+    try self.spillCallerPreservedRegs(resolved_cc);
 
     const param_regs = abi.getCAbiIntParamRegs(resolved_cc);
 
src/register_manager.zig
@@ -55,6 +55,7 @@ pub fn RegisterManager(
         const Self = @This();
 
         pub const TrackedRegisters = [tracked_registers.len]Air.Inst.Index;
+        pub const TrackedIndex = std.math.IntFittingRange(0, tracked_registers.len - 1);
         pub const RegisterBitSet = StaticBitSet(tracked_registers.len);
 
         fn getFunction(self: *Self) *Function {
@@ -66,45 +67,64 @@ pub fn RegisterManager(
             return !register_class.isSet(index);
         }
 
+        fn markRegIndexAllocated(self: *Self, tracked_index: TrackedIndex) void {
+            self.allocated_registers.set(tracked_index);
+        }
         fn markRegAllocated(self: *Self, reg: Register) void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            self.allocated_registers.set(index);
+            self.markRegIndexAllocated(indexOfRegIntoTracked(reg) orelse return);
         }
 
+        fn markRegIndexUsed(self: *Self, tracked_index: TrackedIndex) void {
+            self.free_registers.unset(tracked_index);
+        }
         fn markRegUsed(self: *Self, reg: Register) void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            self.free_registers.unset(index);
+            self.markRegIndexUsed(indexOfRegIntoTracked(reg) orelse return);
         }
 
+        fn markRegIndexFree(self: *Self, tracked_index: TrackedIndex) void {
+            self.free_registers.set(tracked_index);
+        }
         fn markRegFree(self: *Self, reg: Register) void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            self.free_registers.set(index);
+            self.markRegIndexFree(indexOfRegIntoTracked(reg) orelse return);
         }
 
         pub fn indexOfReg(
-            comptime registers: []const Register,
+            comptime set: []const Register,
             reg: Register,
-        ) ?std.math.IntFittingRange(0, registers.len - 1) {
-            inline for (tracked_registers, 0..) |cpreg, i| {
-                if (reg.id() == cpreg.id()) return i;
+        ) ?std.math.IntFittingRange(0, set.len - 1) {
+            const Id = @TypeOf(reg.id());
+            comptime var min_id: Id = std.math.maxInt(Id);
+            comptime var max_id: Id = std.math.minInt(Id);
+            inline for (set) |elem| {
+                const elem_id = comptime elem.id();
+                min_id = @min(elem_id, min_id);
+                max_id = @max(elem_id, max_id);
             }
-            return null;
+
+            const OptionalIndex = std.math.IntFittingRange(0, set.len);
+            comptime var map = [1]OptionalIndex{set.len} ** (max_id + 1 - min_id);
+            inline for (set, 0..) |elem, elem_index| map[comptime elem.id() - min_id] = elem_index;
+
+            const id_index = reg.id() -% min_id;
+            if (id_index >= map.len) return null;
+            const set_index = map[id_index];
+            return if (set_index < set.len) @intCast(set_index) else null;
         }
 
-        pub fn indexOfRegIntoTracked(
-            reg: Register,
-        ) ?std.math.IntFittingRange(0, tracked_registers.len) {
+        pub fn indexOfRegIntoTracked(reg: Register) ?TrackedIndex {
             return indexOfReg(tracked_registers, reg);
         }
 
-        pub fn regAtTrackedIndex(index: std.math.IntFittingRange(0, tracked_registers.len)) Register {
-            return tracked_registers[index];
+        pub fn regAtTrackedIndex(tracked_index: TrackedIndex) Register {
+            return tracked_registers[tracked_index];
         }
 
         /// Returns true when this register is not tracked
+        pub fn isRegIndexFree(self: Self, tracked_index: TrackedIndex) bool {
+            return self.free_registers.isSet(tracked_index);
+        }
         pub fn isRegFree(self: Self, reg: Register) bool {
-            const index = indexOfRegIntoTracked(reg) orelse return true;
-            return self.free_registers.isSet(index);
+            return self.isRegIndexFree(indexOfRegIntoTracked(reg) orelse return true);
         }
 
         /// Returns whether this register was allocated in the course
@@ -119,14 +139,14 @@ pub fn RegisterManager(
         /// Returns whether this register is locked
         ///
         /// Returns false when this register is not tracked
+        fn isRegIndexLocked(self: Self, tracked_index: TrackedIndex) bool {
+            return self.locked_registers.isSet(tracked_index);
+        }
         pub fn isRegLocked(self: Self, reg: Register) bool {
-            const index = indexOfRegIntoTracked(reg) orelse return false;
-            return self.locked_registers.isSet(index);
+            return self.isRegIndexLocked(indexOfRegIntoTracked(reg) orelse return false);
         }
 
-        pub const RegisterLock = struct {
-            register: Register,
-        };
+        pub const RegisterLock = struct { tracked_index: TrackedIndex };
 
         /// Prevents the register from being allocated until they are
         /// unlocked again.
@@ -134,25 +154,29 @@ pub fn RegisterManager(
         /// locked, or `null` otherwise.
         /// Only the owner of the `RegisterLock` can unlock the
         /// register later.
-        pub fn lockReg(self: *Self, reg: Register) ?RegisterLock {
-            log.debug("locking {}", .{reg});
-            if (self.isRegLocked(reg)) {
+        pub fn lockRegIndex(self: *Self, tracked_index: TrackedIndex) ?RegisterLock {
+            log.debug("locking {}", .{regAtTrackedIndex(tracked_index)});
+            if (self.isRegIndexLocked(tracked_index)) {
                 log.debug("  register already locked", .{});
                 return null;
             }
-            const index = indexOfRegIntoTracked(reg) orelse return null;
-            self.locked_registers.set(index);
-            return RegisterLock{ .register = reg };
+            self.locked_registers.set(tracked_index);
+            return RegisterLock{ .tracked_index = tracked_index };
+        }
+        pub fn lockReg(self: *Self, reg: Register) ?RegisterLock {
+            return self.lockRegIndex(indexOfRegIntoTracked(reg) orelse return null);
         }
 
         /// Like `lockReg` but asserts the register was unused always
         /// returning a valid lock.
+        pub fn lockRegIndexAssumeUnused(self: *Self, tracked_index: TrackedIndex) RegisterLock {
+            log.debug("locking asserting free {}", .{regAtTrackedIndex(tracked_index)});
+            assert(!self.isRegIndexLocked(tracked_index));
+            self.locked_registers.set(tracked_index);
+            return RegisterLock{ .tracked_index = tracked_index };
+        }
         pub fn lockRegAssumeUnused(self: *Self, reg: Register) RegisterLock {
-            log.debug("locking asserting free {}", .{reg});
-            assert(!self.isRegLocked(reg));
-            const index = indexOfRegIntoTracked(reg) orelse unreachable;
-            self.locked_registers.set(index);
-            return RegisterLock{ .register = reg };
+            return self.lockRegIndexAssumeUnused(indexOfRegIntoTracked(reg) orelse unreachable);
         }
 
         /// Like `lockReg` but locks multiple registers.
@@ -181,9 +205,8 @@ pub fn RegisterManager(
         /// Requires `RegisterLock` to unlock a register.
         /// Call `lockReg` to obtain the lock first.
         pub fn unlockReg(self: *Self, lock: RegisterLock) void {
-            log.debug("unlocking {}", .{lock.register});
-            const index = indexOfRegIntoTracked(lock.register) orelse return;
-            self.locked_registers.unset(index);
+            log.debug("unlocking {}", .{regAtTrackedIndex(lock.tracked_index)});
+            self.locked_registers.unset(lock.tracked_index);
         }
 
         /// Returns true when at least one register is locked
@@ -319,44 +342,63 @@ pub fn RegisterManager(
         /// Spills the register if it is currently allocated. If a
         /// corresponding instruction is passed, will also track this
         /// register.
-        pub fn getReg(self: *Self, reg: Register, inst: ?Air.Inst.Index) AllocateRegistersError!void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            log.debug("getReg {} for inst {?}", .{ reg, inst });
-
-            if (!self.isRegFree(reg)) {
-                self.markRegAllocated(reg);
+        fn getRegIndex(
+            self: *Self,
+            tracked_index: TrackedIndex,
+            inst: ?Air.Inst.Index,
+        ) AllocateRegistersError!void {
+            log.debug("getReg {} for inst {?}", .{ regAtTrackedIndex(tracked_index), inst });
+            if (!self.isRegIndexFree(tracked_index)) {
+                self.markRegIndexAllocated(tracked_index);
 
                 // Move the instruction that was previously there to a
                 // stack allocation.
-                const spilled_inst = self.registers[index];
-                if (inst) |tracked_inst| self.registers[index] = tracked_inst;
-                try self.getFunction().spillInstruction(reg, spilled_inst);
-                if (inst == null) self.freeReg(reg);
-            } else self.getRegAssumeFree(reg, inst);
+                const spilled_inst = self.registers[tracked_index];
+                if (inst) |tracked_inst| self.registers[tracked_index] = tracked_inst;
+                try self.getFunction().spillInstruction(regAtTrackedIndex(tracked_index), spilled_inst);
+                if (inst == null) self.freeRegIndex(tracked_index);
+            } else self.getRegIndexAssumeFree(tracked_index, inst);
+        }
+        pub fn getReg(self: *Self, reg: Register, inst: ?Air.Inst.Index) AllocateRegistersError!void {
+            return self.getRegIndex(indexOfRegIntoTracked(reg) orelse return, inst);
+        }
+        pub fn getKnownReg(
+            self: *Self,
+            comptime reg: Register,
+            inst: ?Air.Inst.Index,
+        ) AllocateRegistersError!void {
+            return self.getRegIndex((comptime indexOfRegIntoTracked(reg)) orelse return, inst);
         }
 
         /// Allocates the specified register with the specified
         /// instruction. Asserts that the register is free and no
         /// spilling is necessary.
-        pub fn getRegAssumeFree(self: *Self, reg: Register, inst: ?Air.Inst.Index) void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            log.debug("getRegAssumeFree {} for inst {?}", .{ reg, inst });
-            self.markRegAllocated(reg);
+        fn getRegIndexAssumeFree(
+            self: *Self,
+            tracked_index: TrackedIndex,
+            inst: ?Air.Inst.Index,
+        ) void {
+            log.debug("getRegAssumeFree {} for inst {?}", .{ regAtTrackedIndex(tracked_index), inst });
+            self.markRegIndexAllocated(tracked_index);
 
-            assert(self.isRegFree(reg));
+            assert(self.isRegIndexFree(tracked_index));
             if (inst) |tracked_inst| {
-                self.registers[index] = tracked_inst;
-                self.markRegUsed(reg);
+                self.registers[tracked_index] = tracked_inst;
+                self.markRegIndexUsed(tracked_index);
             }
         }
+        pub fn getRegAssumeFree(self: *Self, reg: Register, inst: ?Air.Inst.Index) void {
+            self.getRegIndexAssumeFree(indexOfRegIntoTracked(reg) orelse return, inst);
+        }
 
         /// Marks the specified register as free
+        fn freeRegIndex(self: *Self, tracked_index: TrackedIndex) void {
+            log.debug("freeing register {}", .{regAtTrackedIndex(tracked_index)});
+            self.registers[tracked_index] = undefined;
+            self.markRegIndexFree(tracked_index);
+        }
         pub fn freeReg(self: *Self, reg: Register) void {
-            const index = indexOfRegIntoTracked(reg) orelse return;
-            log.debug("freeing register {}", .{reg});
-
-            self.registers[index] = undefined;
-            self.markRegFree(reg);
+            self.freeRegIndex(indexOfRegIntoTracked(reg) orelse return);
         }
     };
 }