Commit 283f40e4e9

Jakub Konka <kubkon@jakubkonka.com>
2022-05-19 15:34:13
x64: use StaticBitSet instead of an integer internally in RegisterManager
1 parent 080d138
Changed files (3)
src/arch/x86_64/abi.zig
@@ -3,6 +3,7 @@ const Type = @import("../../type.zig").Type;
 const Target = std.Target;
 const assert = std.debug.assert;
 const Register = @import("bits.zig").Register;
+const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager;
 
 pub const Class = enum { integer, sse, sseup, x87, x87up, complex_x87, memory, none };
 
@@ -378,18 +379,34 @@ pub const callee_preserved_regs = [_]Register{ .rbx, .r12, .r13, .r14, .r15 };
 /// the caller relinquishes control to a subroutine via call instruction (or similar).
 /// In other words, these registers are free to use by the callee.
 pub const caller_preserved_regs = [_]Register{ .rax, .rcx, .rdx, .rsi, .rdi, .r8, .r9, .r10, .r11 };
-pub const avx_regs = [_]Register{
-    .ymm0, .ymm1, .ymm2,  .ymm3,  .ymm4,  .ymm5,  .ymm6,  .ymm7,
-    .ymm8, .ymm9, .ymm10, .ymm11, .ymm12, .ymm13, .ymm14, .ymm15,
-};
-pub const allocatable_registers = callee_preserved_regs ++ caller_preserved_regs ++ avx_regs;
 
 pub const c_abi_int_param_regs = [_]Register{ .rdi, .rsi, .rdx, .rcx, .r8, .r9 };
 pub const c_abi_int_return_regs = [_]Register{ .rax, .rdx };
 
-// Masks for register manager
-const FreeRegInt = std.meta.Int(.unsigned, allocatable_registers.len);
+const avx_regs = [_]Register{
+    .ymm0, .ymm1, .ymm2,  .ymm3,  .ymm4,  .ymm5,  .ymm6,  .ymm7,
+    .ymm8, .ymm9, .ymm10, .ymm11, .ymm12, .ymm13, .ymm14, .ymm15,
+};
+const allocatable_registers = callee_preserved_regs ++ caller_preserved_regs ++ avx_regs;
+pub const RegisterManager = RegisterManagerFn(@import("CodeGen.zig"), Register, &allocatable_registers);
+
+// Register classes
+const RegisterBitSet = RegisterManager.RegisterBitSet;
 pub const RegisterClass = struct {
-    pub const gp: FreeRegInt = 0x3fff;
-    pub const avx: FreeRegInt = 0x3fff_c000;
+    pub const gp: RegisterBitSet = blk: {
+        var set = RegisterBitSet.initEmpty();
+        set.setRangeValue(.{
+            .start = 0,
+            .end = caller_preserved_regs.len + callee_preserved_regs.len,
+        }, true);
+        break :blk set;
+    };
+    pub const avx: RegisterBitSet = blk: {
+        var set = RegisterBitSet.initEmpty();
+        set.setRangeValue(.{
+            .start = caller_preserved_regs.len + callee_preserved_regs.len,
+            .end = allocatable_registers.len,
+        }, true);
+        break :blk set;
+    };
 };
src/arch/x86_64/CodeGen.zig
@@ -21,7 +21,6 @@ const Emit = @import("Emit.zig");
 const Liveness = @import("../../Liveness.zig");
 const Mir = @import("Mir.zig");
 const Module = @import("../../Module.zig");
-const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager;
 const Target = std.Target;
 const Type = @import("../../type.zig").Type;
 const TypedValue = @import("../../TypedValue.zig");
@@ -32,15 +31,15 @@ const abi = @import("abi.zig");
 
 const callee_preserved_regs = abi.callee_preserved_regs;
 const caller_preserved_regs = abi.caller_preserved_regs;
-const allocatable_registers = abi.allocatable_registers;
 const c_abi_int_param_regs = abi.c_abi_int_param_regs;
 const c_abi_int_return_regs = abi.c_abi_int_return_regs;
-const RegisterManager = RegisterManagerFn(Self, Register, &allocatable_registers);
+
+const RegisterManager = abi.RegisterManager;
 const RegisterLock = RegisterManager.RegisterLock;
 const Register = bits.Register;
-const RegisterClass = abi.RegisterClass;
-const gp = RegisterClass.gp;
-const avx = RegisterClass.avx;
+
+const gp = abi.RegisterClass.gp;
+const avx = abi.RegisterClass.avx;
 
 const InnerError = error{
     OutOfMemory,
src/register_manager.zig
@@ -4,6 +4,7 @@ const mem = std.mem;
 const assert = std.debug.assert;
 const Allocator = std.mem.Allocator;
 const Air = @import("Air.zig");
+const StaticBitSet = std.bit_set.StaticBitSet;
 const Type = @import("type.zig").Type;
 const Module = @import("Module.zig");
 const expect = std.testing.expect;
@@ -41,66 +42,54 @@ pub fn RegisterManager(
         registers: [tracked_registers.len]Air.Inst.Index = undefined,
         /// Tracks which registers are free (in which case the
         /// corresponding bit is set to 1)
-        free_registers: FreeRegInt = math.maxInt(FreeRegInt),
+        free_registers: RegisterBitSet = RegisterBitSet.initFull(),
         /// Tracks all registers allocated in the course of this
         /// function
-        allocated_registers: FreeRegInt = 0,
+        allocated_registers: RegisterBitSet = RegisterBitSet.initEmpty(),
         /// Tracks registers which are locked from being allocated
-        locked_registers: FreeRegInt = 0,
+        locked_registers: RegisterBitSet = RegisterBitSet.initEmpty(),
 
         const Self = @This();
 
-        /// An integer whose bits represent all the registers and
-        /// whether they are free.
-        const FreeRegInt = std.meta.Int(.unsigned, tracked_registers.len);
-        const ShiftInt = math.Log2Int(FreeRegInt);
+        pub const RegisterBitSet = StaticBitSet(tracked_registers.len);
 
         fn getFunction(self: *Self) *Function {
             return @fieldParentPtr(Function, "register_manager", self);
         }
 
-        fn getRegisterMask(reg: Register) ?FreeRegInt {
-            const index = indexOfRegIntoTracked(reg) orelse return null;
-            const shift = @intCast(ShiftInt, index);
-            const mask = @as(FreeRegInt, 1) << shift;
-            return mask;
-        }
-
-        fn excludeRegister(reg: Register, mask: FreeRegInt) bool {
-            const reg_mask = getRegisterMask(reg) orelse return true;
-            return reg_mask & mask == 0;
-        }
-
         fn markRegAllocated(self: *Self, reg: Register) void {
-            const mask = getRegisterMask(reg) orelse return;
-            self.allocated_registers |= mask;
+            const index = indexOfRegIntoTracked(reg) orelse return;
+            self.allocated_registers.set(index);
         }
 
         fn markRegUsed(self: *Self, reg: Register) void {
-            const mask = getRegisterMask(reg) orelse return;
-            self.free_registers &= ~mask;
+            const index = indexOfRegIntoTracked(reg) orelse return;
+            self.free_registers.unset(index);
         }
 
         fn markRegFree(self: *Self, reg: Register) void {
-            const mask = getRegisterMask(reg) orelse return;
-            self.free_registers |= mask;
+            const index = indexOfRegIntoTracked(reg) orelse return;
+            self.free_registers.set(index);
         }
 
-        pub fn indexOfReg(comptime registers: []const Register, reg: Register) ?std.math.IntFittingRange(0, registers.len - 1) {
+        pub fn indexOfReg(
+            comptime registers: []const Register,
+            reg: Register,
+        ) ?std.math.IntFittingRange(0, registers.len - 1) {
             inline for (tracked_registers) |cpreg, i| {
                 if (reg.id() == cpreg.id()) return i;
             }
             return null;
         }
 
-        pub fn indexOfRegIntoTracked(reg: Register) ?ShiftInt {
+        pub fn indexOfRegIntoTracked(reg: Register) ?RegisterBitSet.ShiftInt {
             return indexOfReg(tracked_registers, reg);
         }
 
         /// Returns true when this register is not tracked
         pub fn isRegFree(self: Self, reg: Register) bool {
-            const mask = getRegisterMask(reg) orelse return true;
-            return self.free_registers & mask != 0;
+            const index = indexOfRegIntoTracked(reg) orelse return true;
+            return self.free_registers.isSet(index);
         }
 
         /// Returns whether this register was allocated in the course
@@ -108,16 +97,16 @@ pub fn RegisterManager(
         ///
         /// Returns false when this register is not tracked
         pub fn isRegAllocated(self: Self, reg: Register) bool {
-            const mask = getRegisterMask(reg) orelse return false;
-            return self.allocated_registers & mask != 0;
+            const index = indexOfRegIntoTracked(reg) orelse return false;
+            return self.allocated_registers.isSet(index);
         }
 
         /// Returns whether this register is locked
         ///
         /// Returns false when this register is not tracked
         pub fn isRegLocked(self: Self, reg: Register) bool {
-            const mask = getRegisterMask(reg) orelse return false;
-            return self.locked_registers & mask != 0;
+            const index = indexOfRegIntoTracked(reg) orelse return false;
+            return self.locked_registers.isSet(index);
         }
 
         pub const RegisterLock = struct {
@@ -136,8 +125,8 @@ pub fn RegisterManager(
                 log.debug("  register already locked", .{});
                 return null;
             }
-            const mask = getRegisterMask(reg) orelse return null;
-            self.locked_registers |= mask;
+            const index = indexOfRegIntoTracked(reg) orelse return null;
+            self.locked_registers.set(index);
             return RegisterLock{ .register = reg };
         }
 
@@ -146,8 +135,8 @@ pub fn RegisterManager(
         pub fn lockRegAssumeUnused(self: *Self, reg: Register) RegisterLock {
             log.debug("locking asserting free {}", .{reg});
             assert(!self.isRegLocked(reg));
-            const mask = getRegisterMask(reg) orelse unreachable;
-            self.locked_registers |= mask;
+            const index = indexOfRegIntoTracked(reg) orelse unreachable;
+            self.locked_registers.set(index);
             return RegisterLock{ .register = reg };
         }
 
@@ -169,17 +158,17 @@ pub fn RegisterManager(
         /// Call `lockReg` to obtain the lock first.
         pub fn unlockReg(self: *Self, lock: RegisterLock) void {
             log.debug("unlocking {}", .{lock.register});
-            const mask = getRegisterMask(lock.register) orelse return;
-            self.locked_registers &= ~mask;
+            const index = indexOfRegIntoTracked(lock.register) orelse return;
+            self.locked_registers.unset(index);
         }
 
         /// Returns true when at least one register is locked
         pub fn lockedRegsExist(self: Self) bool {
-            return self.locked_registers != 0;
+            return self.locked_registers.count() > 0;
         }
 
         const AllocOpts = struct {
-            selector_mask: ?FreeRegInt = null,
+            selector_mask: ?RegisterBitSet = null,
         };
 
         /// Allocates a specified number of registers, optionally
@@ -193,17 +182,22 @@ pub fn RegisterManager(
         ) ?[count]Register {
             comptime assert(count > 0 and count <= tracked_registers.len);
 
-            const selector_mask = if (opts.selector_mask) |mask| mask else ~@as(FreeRegInt, 0);
-            const free_registers = self.free_registers & selector_mask;
-            const free_and_not_locked_registers = free_registers & ~self.locked_registers;
-            const free_and_not_locked_registers_count = @popCount(FreeRegInt, free_and_not_locked_registers);
-            if (free_and_not_locked_registers_count < count) return null;
+            const available_registers = opts.selector_mask orelse RegisterBitSet.initFull();
+
+            var free_and_not_locked_registers = self.free_registers;
+            free_and_not_locked_registers.setIntersection(available_registers);
+
+            var unlocked_registers = self.locked_registers;
+            unlocked_registers.toggleAll();
+
+            free_and_not_locked_registers.setIntersection(unlocked_registers);
+
+            if (free_and_not_locked_registers.count() < count) return null;
 
             var regs: [count]Register = undefined;
             var i: usize = 0;
             for (tracked_registers) |reg| {
                 if (i >= count) break;
-                if (excludeRegister(reg, selector_mask)) continue;
                 if (self.isRegLocked(reg)) continue;
                 if (!self.isRegFree(reg)) continue;
 
@@ -244,11 +238,12 @@ pub fn RegisterManager(
         ) AllocateRegistersError![count]Register {
             comptime assert(count > 0 and count <= tracked_registers.len);
 
-            const selector_mask = if (opts.selector_mask) |mask| mask else ~@as(FreeRegInt, 0);
-            const available_registers_count = @popCount(FreeRegInt, selector_mask);
-            const locked_registers = self.locked_registers & selector_mask;
-            const locked_registers_count = @popCount(FreeRegInt, locked_registers);
-            if (count > available_registers_count - locked_registers_count) return error.OutOfRegisters;
+            const available_registers = opts.selector_mask orelse RegisterBitSet.initFull();
+
+            var locked_registers = self.locked_registers;
+            locked_registers.setIntersection(available_registers);
+
+            if (count > available_registers.count() - locked_registers.count()) return error.OutOfRegisters;
 
             const result = self.tryAllocRegs(count, insts, opts) orelse blk: {
                 // We'll take over the first count registers. Spill
@@ -258,7 +253,6 @@ pub fn RegisterManager(
                 var i: usize = 0;
                 for (tracked_registers) |reg| {
                     if (i >= count) break;
-                    if (excludeRegister(reg, selector_mask)) continue;
                     if (self.isRegLocked(reg)) continue;
 
                     regs[i] = reg;