Commit 117f9f69e7

Jakub Konka <kubkon@jakubkonka.com>
2022-06-07 11:16:08
x64: simplify saving registers to stack in prologue
1 parent e9fc58e
Changed files (4)
src/arch/x86_64/CodeGen.zig
@@ -409,13 +409,11 @@ fn gen(self: *Self) InnerError!void {
             // The address where to store the return value for the caller is in `.rdi`
             // register which the callee is free to clobber. Therefore, we purposely
             // spill it to stack immediately.
-            const ptr_ty = Type.usize;
-            const abi_size = @intCast(u32, ptr_ty.abiSize(self.target.*));
-            const abi_align = ptr_ty.abiAlignment(self.target.*);
-            const stack_offset = mem.alignForwardGeneric(u32, self.next_stack_offset + abi_size, abi_align);
+            const stack_offset = mem.alignForwardGeneric(u32, self.next_stack_offset + 8, 8);
             self.next_stack_offset = stack_offset;
             self.max_end_stack = @maximum(self.max_end_stack, self.next_stack_offset);
-            try self.genSetStack(ptr_ty, @intCast(i32, stack_offset), MCValue{ .register = .rdi }, .{});
+
+            try self.genSetStack(Type.usize, @intCast(i32, stack_offset), MCValue{ .register = .rdi }, .{});
             self.ret_mcv = MCValue{ .stack_offset = @intCast(i32, stack_offset) };
             log.debug("gen: spilling .rdi to stack at offset {}", .{stack_offset});
         }
@@ -426,11 +424,11 @@ fn gen(self: *Self) InnerError!void {
             .data = undefined,
         });
 
-        // push the callee_preserved_regs that were used
-        const backpatch_push_callee_preserved_regs_i = try self.addInst(.{
-            .tag = .push_regs_from_callee_preserved_regs,
-            .ops = Mir.Inst.Ops.encode(.{ .reg1 = .rbp }),
-            .data = .{ .payload = undefined }, // to be backpatched
+        // Push callee-preserved regs that were used actually in use.
+        const backpatch_push_callee_preserved_regs = try self.addInst(.{
+            .tag = .nop,
+            .ops = undefined,
+            .data = undefined,
         });
 
         try self.genBody(self.air.getMainBody());
@@ -446,31 +444,21 @@ fn gen(self: *Self) InnerError!void {
             self.mir_instructions.items(.data)[jmp_reloc].inst = @intCast(u32, self.mir_instructions.len);
         }
 
-        // calculate the data for callee_preserved_regs to be pushed and popped
-        const callee_preserved_regs_payload = blk: {
-            var data = Mir.RegsToPushOrPop{
-                .regs = 0,
-                .disp = mem.alignForwardGeneric(u32, self.next_stack_offset, 8),
-            };
-            var disp = data.disp + 8;
-            inline for (callee_preserved_regs) |reg, i| {
-                if (self.register_manager.isRegAllocated(reg)) {
-                    data.regs |= 1 << @intCast(u5, i);
-                    self.max_end_stack += 8;
-                    disp += 8;
-                }
+        // Create list of registers to save in the prologue.
+        // TODO handle register classes
+        var reg_list: Mir.RegisterList(Register, &callee_preserved_regs) = .{};
+        inline for (callee_preserved_regs) |reg| {
+            if (self.register_manager.isRegAllocated(reg)) {
+                reg_list.push(reg);
             }
-            break :blk try self.addExtra(data);
-        };
+        }
+        const saved_regs_stack_space: u32 = reg_list.count() * 8;
 
-        const data = self.mir_instructions.items(.data);
-        // backpatch the push instruction
-        data[backpatch_push_callee_preserved_regs_i].payload = callee_preserved_regs_payload;
-        // pop the callee_preserved_regs
-        _ = try self.addInst(.{
-            .tag = .pop_regs_from_callee_preserved_regs,
-            .ops = Mir.Inst.Ops.encode(.{ .reg1 = .rbp }),
-            .data = .{ .payload = callee_preserved_regs_payload },
+        // Pop saved callee-preserved regs.
+        const backpatch_pop_callee_preserved_regs = try self.addInst(.{
+            .tag = .nop,
+            .ops = undefined,
+            .data = undefined,
         });
 
         _ = try self.addInst(.{
@@ -502,9 +490,11 @@ fn gen(self: *Self) InnerError!void {
         if (self.max_end_stack > math.maxInt(i32)) {
             return self.failSymbol("too much stack used in call parameters", .{});
         }
-        // TODO we should reuse this mechanism to align the stack when calling any function even if
-        // we do not pass any args on the stack BUT we still push regs to stack with `push` inst.
-        const aligned_stack_end = @intCast(u32, mem.alignForward(self.max_end_stack, self.stack_align));
+
+        const aligned_stack_end = @intCast(
+            u32,
+            mem.alignForward(self.max_end_stack + saved_regs_stack_space, self.stack_align),
+        );
         if (aligned_stack_end > 0) {
             self.mir_instructions.set(backpatch_stack_sub, .{
                 .tag = .sub,
@@ -516,6 +506,21 @@ fn gen(self: *Self) InnerError!void {
                 .ops = Mir.Inst.Ops.encode(.{ .reg1 = .rsp }),
                 .data = .{ .imm = aligned_stack_end },
             });
+
+            const save_reg_list = try self.addExtra(Mir.SaveRegisterList{
+                .register_list = reg_list.asInt(),
+                .stack_end = aligned_stack_end,
+            });
+            self.mir_instructions.set(backpatch_push_callee_preserved_regs, .{
+                .tag = .push_regs,
+                .ops = Mir.Inst.Ops.encode(.{ .reg1 = .rbp }),
+                .data = .{ .payload = save_reg_list },
+            });
+            self.mir_instructions.set(backpatch_pop_callee_preserved_regs, .{
+                .tag = .pop_regs,
+                .ops = Mir.Inst.Ops.encode(.{ .reg1 = .rbp }),
+                .data = .{ .payload = save_reg_list },
+            });
         }
     } else {
         _ = try self.addInst(.{
@@ -907,6 +912,39 @@ fn allocRegOrMem(self: *Self, inst: Air.Inst.Index, reg_ok: bool) !MCValue {
     return MCValue{ .stack_offset = @intCast(i32, stack_offset) };
 }
 
+const State = struct {
+    next_stack_offset: u32,
+    registers: abi.RegisterManager.TrackedRegisters,
+    free_registers: abi.RegisterManager.RegisterBitSet,
+    eflags_inst: ?Air.Inst.Index,
+    stack: std.AutoHashMapUnmanaged(u32, StackAllocation),
+
+    fn deinit(state: *State, gpa: Allocator) void {
+        state.stack.deinit(gpa);
+    }
+};
+
+fn captureState(self: *Self) !State {
+    return State{
+        .next_stack_offset = self.next_stack_offset,
+        .registers = self.register_manager.registers,
+        .free_registers = self.register_manager.free_registers,
+        .eflags_inst = self.eflags_inst,
+        .stack = try self.stack.clone(self.gpa),
+    };
+}
+
+fn revertState(self: *Self, state: State) void {
+    self.register_manager.registers = state.registers;
+    self.eflags_inst = state.eflags_inst;
+
+    self.stack.deinit(self.gpa);
+    self.stack = state.stack;
+
+    self.next_stack_offset = state.next_stack_offset;
+    self.register_manager.free_registers = state.free_registers;
+}
+
 pub fn spillInstruction(self: *Self, reg: Register, inst: Air.Inst.Index) !void {
     const stack_mcv = try self.allocRegOrMem(inst, false);
     log.debug("spilling {d} to stack mcv {any}", .{ inst, stack_mcv });
@@ -4503,12 +4541,7 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
     }
 
     // Capture the state of register and stack allocation state so that we can revert to it.
-    const parent_next_stack_offset = self.next_stack_offset;
-    const parent_free_registers = self.register_manager.free_registers;
-    const parent_eflags_inst = self.eflags_inst;
-    var parent_stack = try self.stack.clone(self.gpa);
-    defer parent_stack.deinit(self.gpa);
-    const parent_registers = self.register_manager.registers;
+    const saved_state = try self.captureState();
 
     try self.branch_stack.append(.{});
     errdefer {
@@ -4526,17 +4559,10 @@ fn airCondBr(self: *Self, inst: Air.Inst.Index) !void {
     var saved_then_branch = self.branch_stack.pop();
     defer saved_then_branch.deinit(self.gpa);
 
-    self.register_manager.registers = parent_registers;
-    self.eflags_inst = parent_eflags_inst;
-
-    self.stack.deinit(self.gpa);
-    self.stack = parent_stack;
-    parent_stack = .{};
-
-    self.next_stack_offset = parent_next_stack_offset;
-    self.register_manager.free_registers = parent_free_registers;
+    self.revertState(saved_state);
 
     try self.performReloc(reloc);
+
     const else_branch = self.branch_stack.addOneAssumeCapacity();
     else_branch.* = .{};
 
@@ -5021,12 +5047,7 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void {
         }
 
         // Capture the state of register and stack allocation state so that we can revert to it.
-        const parent_next_stack_offset = self.next_stack_offset;
-        const parent_free_registers = self.register_manager.free_registers;
-        const parent_eflags_inst = self.eflags_inst;
-        var parent_stack = try self.stack.clone(self.gpa);
-        defer parent_stack.deinit(self.gpa);
-        const parent_registers = self.register_manager.registers;
+        const saved_state = try self.captureState();
 
         try self.branch_stack.append(.{});
         errdefer {
@@ -5044,14 +5065,7 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void {
         var saved_case_branch = self.branch_stack.pop();
         defer saved_case_branch.deinit(self.gpa);
 
-        self.register_manager.registers = parent_registers;
-        self.eflags_inst = parent_eflags_inst;
-        self.stack.deinit(self.gpa);
-        self.stack = parent_stack;
-        parent_stack = .{};
-
-        self.next_stack_offset = parent_next_stack_offset;
-        self.register_manager.free_registers = parent_free_registers;
+        self.revertState(saved_state);
 
         for (relocs) |reloc| {
             try self.performReloc(reloc);
src/arch/x86_64/Emit.zig
@@ -169,7 +169,7 @@ pub fn lowerMir(emit: *Emit) InnerError!void {
             .@"test" => try emit.mirTest(inst),
 
             .interrupt => try emit.mirInterrupt(inst),
-            .nop => try emit.mirNop(),
+            .nop => {}, // just skip it
 
             // SSE instructions
             .mov_f64_sse => try emit.mirMovFloatSse(.movsd, inst),
@@ -198,8 +198,8 @@ pub fn lowerMir(emit: *Emit) InnerError!void {
             .dbg_prologue_end => try emit.mirDbgPrologueEnd(inst),
             .dbg_epilogue_begin => try emit.mirDbgEpilogueBegin(inst),
 
-            .push_regs_from_callee_preserved_regs => try emit.mirPushPopRegsFromCalleePreservedRegs(.push, inst),
-            .pop_regs_from_callee_preserved_regs => try emit.mirPushPopRegsFromCalleePreservedRegs(.pop, inst),
+            .push_regs => try emit.mirPushPopRegisterList(.push, inst),
+            .pop_regs => try emit.mirPushPopRegisterList(.pop, inst),
 
             else => {
                 return emit.fail("Implement MIR->Emit lowering for x86_64 for pseudo-inst: {s}", .{tag});
@@ -246,10 +246,6 @@ fn mirInterrupt(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     }
 }
 
-fn mirNop(emit: *Emit) InnerError!void {
-    return lowerToZoEnc(.nop, emit.code);
-}
-
 fn mirSyscall(emit: *Emit) InnerError!void {
     return lowerToZoEnc(.syscall, emit.code);
 }
@@ -283,26 +279,27 @@ fn mirPushPop(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     }
 }
 
-fn mirPushPopRegsFromCalleePreservedRegs(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
+fn mirPushPopRegisterList(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = emit.mir.instructions.items(.ops)[inst].decode();
     const payload = emit.mir.instructions.items(.data)[inst].payload;
-    const data = emit.mir.extraData(Mir.RegsToPushOrPop, payload).data;
-    const regs = data.regs;
-    var disp: u32 = data.disp + 8;
-    for (abi.callee_preserved_regs) |reg, i| {
-        if ((regs >> @intCast(u5, i)) & 1 == 0) continue;
-        if (tag == .push) {
-            try lowerToMrEnc(.mov, RegisterOrMemory.mem(.qword_ptr, .{
-                .disp = @bitCast(u32, -@intCast(i32, disp)),
-                .base = ops.reg1,
-            }), reg.to64(), emit.code);
-        } else {
-            try lowerToRmEnc(.mov, reg.to64(), RegisterOrMemory.mem(.qword_ptr, .{
-                .disp = @bitCast(u32, -@intCast(i32, disp)),
-                .base = ops.reg1,
-            }), emit.code);
+    const save_reg_list = emit.mir.extraData(Mir.SaveRegisterList, payload).data;
+    const reg_list = Mir.RegisterList(Register, &abi.callee_preserved_regs).fromInt(save_reg_list.register_list);
+    var disp: i32 = -@intCast(i32, save_reg_list.stack_end);
+    inline for (abi.callee_preserved_regs) |reg| {
+        if (reg_list.isSet(reg)) {
+            switch (tag) {
+                .push => try lowerToMrEnc(.mov, RegisterOrMemory.mem(.qword_ptr, .{
+                    .disp = @bitCast(u32, disp),
+                    .base = ops.reg1,
+                }), reg, emit.code),
+                .pop => try lowerToRmEnc(.mov, reg, RegisterOrMemory.mem(.qword_ptr, .{
+                    .disp = @bitCast(u32, disp),
+                    .base = ops.reg1,
+                }), emit.code),
+                else => unreachable,
+            }
+            disp += 8;
         }
-        disp += 8;
     }
 }
 
src/arch/x86_64/Mir.zig
@@ -14,6 +14,7 @@ const assert = std.debug.assert;
 const bits = @import("bits.zig");
 const Air = @import("../../Air.zig");
 const CodeGen = @import("CodeGen.zig");
+const IntegerBitSet = std.bit_set.IntegerBitSet;
 const Register = bits.Register;
 
 instructions: std.MultiArrayList(Inst).Slice,
@@ -379,19 +380,13 @@ pub const Inst = struct {
         /// update debug line
         dbg_line,
 
-        /// push registers from the callee_preserved_regs
-        /// data is the bitfield of which regs to push
-        /// for example on x86_64, the callee_preserved_regs are [_]Register{ .rcx, .rsi, .rdi, .r8, .r9, .r10, .r11 };    };
-        /// so to push rcx and r8 one would make data 0b00000000_00000000_00000000_00001001 (the first and fourth bits are set)
-        /// ops is unused
-        push_regs_from_callee_preserved_regs,
-
-        /// pop registers from the callee_preserved_regs
-        /// data is the bitfield of which regs to pop
-        /// for example on x86_64, the callee_preserved_regs are [_]Register{ .rcx, .rsi, .rdi, .r8, .r9, .r10, .r11 };    };
-        /// so to pop rcx and r8 one would make data 0b00000000_00000000_00000000_00001001 (the first and fourth bits are set)
-        /// ops is unused
-        pop_regs_from_callee_preserved_regs,
+        /// push registers
+        /// Uses `payload` field with `SaveRegisterList` as payload.
+        push_regs,
+
+        /// pop registers
+        /// Uses `payload` field with `SaveRegisterList` as payload.
+        pop_regs,
     };
     /// The position of an MIR instruction within the `Mir` instructions array.
     pub const Index = u32;
@@ -471,9 +466,51 @@ pub const Inst = struct {
     }
 };
 
-pub const RegsToPushOrPop = struct {
-    regs: u32,
-    disp: u32,
+pub fn RegisterList(comptime Reg: type, comptime registers: []const Reg) type {
+    assert(registers.len <= @bitSizeOf(u32));
+    return struct {
+        bitset: RegBitSet = RegBitSet.initEmpty(),
+
+        const RegBitSet = IntegerBitSet(registers.len);
+        const Self = @This();
+
+        fn getIndexForReg(reg: Reg) RegBitSet.MaskInt {
+            inline for (registers) |cpreg, i| {
+                if (reg.id() == cpreg.id()) return i;
+            }
+            unreachable; // register not in input register list!
+        }
+
+        pub fn push(self: *Self, reg: Reg) void {
+            const index = getIndexForReg(reg);
+            self.bitset.set(index);
+        }
+
+        pub fn isSet(self: Self, reg: Reg) bool {
+            const index = getIndexForReg(reg);
+            return self.bitset.isSet(index);
+        }
+
+        pub fn asInt(self: Self) u32 {
+            return self.bitset.mask;
+        }
+
+        pub fn fromInt(mask: u32) Self {
+            return .{
+                .bitset = RegBitSet{ .mask = @intCast(RegBitSet.MaskInt, mask) },
+            };
+        }
+
+        pub fn count(self: Self) u32 {
+            return @intCast(u32, self.bitset.count());
+        }
+    };
+}
+
+pub const SaveRegisterList = struct {
+    /// Use `RegisterList` to populate.
+    register_list: u32,
+    stack_end: u32,
 };
 
 pub const ImmPair = struct {
src/register_manager.zig
@@ -39,7 +39,7 @@ pub fn RegisterManager(
         /// register is free), the value in that slot is undefined.
         ///
         /// The key must be canonical register.
-        registers: [tracked_registers.len]Air.Inst.Index = undefined,
+        registers: TrackedRegisters = undefined,
         /// Tracks which registers are free (in which case the
         /// corresponding bit is set to 1)
         free_registers: RegisterBitSet = RegisterBitSet.initFull(),
@@ -51,6 +51,7 @@ pub fn RegisterManager(
 
         const Self = @This();
 
+        pub const TrackedRegisters = [tracked_registers.len]Air.Inst.Index;
         pub const RegisterBitSet = StaticBitSet(tracked_registers.len);
 
         fn getFunction(self: *Self) *Function {