Commit 896472c20e

Andrew Kelley <andrew@ziglang.org>
2020-07-18 00:51:15
stage2: implement register copying
1 parent ef9aeb6
Changed files (4)
src-self-hosted
test
src-self-hosted/codegen/x86.zig
@@ -25,6 +25,20 @@ pub const Register = enum(u8) {
     pub fn id(self: @This()) u3 {
         return @truncate(u3, @enumToInt(self));
     }
+
+    /// Returns the index into `callee_preserved_regs` (`.ax`/`.al` share the index of `.eax`, and so on), or `null` for any other register.
+    pub fn allocIndex(self: Register) ?u4 {
+        return switch (self) {
+            .eax, .ax, .al => 0,
+            .ecx, .cx, .cl => 1,
+            .edx, .dx, .dl => 2,
+            .esi, .si  => 3,
+            .edi, .di => 4,
+            else => null, // no slot in `callee_preserved_regs`
+        };
+    }
 };
 
 // zig fmt: on
+
+pub const callee_preserved_regs = [_]Register{ .eax, .ecx, .edx, .esi, .edi }; // order must match the indices returned by `Register.allocIndex`
src-self-hosted/codegen/x86_64.zig
@@ -38,7 +38,7 @@ pub const Register = enum(u8) {
     r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b,
 
     /// Returns the bit-width of the register.
-    pub fn size(self: @This()) u7 {
+    pub fn size(self: Register) u7 {
         return switch (@enumToInt(self)) {
             0...15 => 64,
             16...31 => 32,
@@ -53,7 +53,7 @@ pub const Register = enum(u8) {
     /// other variant of access to those registers, such as r8b, r15d, and so
     /// on. This is needed because access to these registers requires special
     /// handling via the REX prefix, via the B or R bits, depending on context.
-    pub fn isExtended(self: @This()) bool {
+    pub fn isExtended(self: Register) bool {
         return @enumToInt(self) & 0x08 != 0;
     }
 
@@ -62,12 +62,29 @@ pub const Register = enum(u8) {
     /// an instruction (@see isExtended), and requires special handling. The
     /// lower three bits are often embedded directly in instructions (such as
     /// the B8 variant of moves), or used in R/M bytes.
-    pub fn id(self: @This()) u4 {
+    pub fn id(self: Register) u4 {
         return @truncate(u4, @enumToInt(self));
     }
+
+    /// Returns the index into `callee_preserved_regs` (`.eax`/`.ax`/`.al` share the index of `.rax`, and so on), or `null` for any other register.
+    pub fn allocIndex(self: Register) ?u4 {
+        return switch (self) {
+            .rax, .eax, .ax, .al => 0,
+            .rcx, .ecx, .cx, .cl => 1,
+            .rdx, .edx, .dx, .dl => 2,
+            .rsi, .esi, .si  => 3,
+            .rdi, .edi, .di => 4,
+            .r8, .r8d, .r8w, .r8b => 5,
+            .r9, .r9d, .r9w, .r9b => 6,
+            .r10, .r10d, .r10w, .r10b => 7,
+            .r11, .r11d, .r11w, .r11b => 8,
+            else => null, // no slot in `callee_preserved_regs`
+        };
+    }
 };
 
 // zig fmt: on
 
 /// These registers belong to the called function.
-pub const callee_preserved = [_]Register{ rax, rcx, rdx, rsi, rdi, r8, r9, r10, r11 };
+pub const callee_preserved_regs = [_]Register{ .rax, .rcx, .rdx, .rsi, .rdi, .r8, .r9, .r10, .r11 }; // order must match the indices returned by `Register.allocIndex`
+pub const c_abi_int_param_regs = [_]Register{ .rdi, .rsi, .rdx, .rcx, .r8, .r9 }; // handed out in this order to Bool/Int parameters by codegen.zig
src-self-hosted/codegen.zig
@@ -11,8 +11,6 @@ const ErrorMsg = Module.ErrorMsg;
 const Target = std.Target;
 const Allocator = mem.Allocator;
 const trace = @import("tracy.zig").trace;
-const x86_64 = @import("codegen/x86_64.zig");
-const x86 = @import("codegen/x86.zig");
 
 /// The codegen-related data that is stored in `ir.Inst.Block` instructions.
 pub const BlockData = struct {
@@ -232,7 +230,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             /// The constant was emitted into the code, at this offset.
             embedded_in_code: usize,
             /// The value is in a target-specific register.
-            register: Reg,
+            register: Register,
             /// The value is in memory at a hard-coded address.
             memory: u64,
             /// The value is one of the stack variables.
@@ -280,9 +278,8 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
 
         const Branch = struct {
             inst_table: std.AutoHashMapUnmanaged(*ir.Inst, MCValue) = .{},
-
-            /// The key is an enum value of an arch-specific register.
-            registers: std.AutoHashMapUnmanaged(usize, RegisterAllocation) = .{},
+            registers: std.AutoHashMapUnmanaged(Register, RegisterAllocation) = .{},
+            free_registers: FreeRegInt = std.math.maxInt(FreeRegInt),
 
             /// Maps offset to what is stored there.
             stack: std.AutoHashMapUnmanaged(usize, StackAllocation) = .{},
@@ -292,6 +289,20 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             /// to place a new stack allocation, it goes here, and then bumps `max_end_stack`.
             next_stack_offset: u32 = 0,
 
+            fn markRegUsed(self: *Branch, reg: Register) void { // clears `reg`'s bit in `free_registers`
+                const index = reg.allocIndex() orelse return; // registers without an alloc index are not tracked; nothing to do
+                const ShiftInt = std.math.Log2Int(FreeRegInt); // shift-amount type for a `FreeRegInt`
+                const shift = @intCast(ShiftInt, index);
+                self.free_registers &= ~(@as(FreeRegInt, 1) << shift); // a cleared bit means "in use"
+            }
+
+            fn markRegFree(self: *Branch, reg: Register) void { // sets `reg`'s bit in `free_registers`
+                const index = reg.allocIndex() orelse return; // registers without an alloc index are not tracked; nothing to do
+                const ShiftInt = std.math.Log2Int(FreeRegInt); // shift-amount type for a `FreeRegInt`
+                const shift = @intCast(ShiftInt, index);
+                self.free_registers |= @as(FreeRegInt, 1) << shift; // a set bit means "free"
+            }
+
             fn deinit(self: *Branch, gpa: *Allocator) void {
                 self.inst_table.deinit(gpa);
                 self.registers.deinit(gpa);
@@ -516,7 +527,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 // Both operands cannot be memory.
                 src_inst = op_rhs;
                 if (lhs.isMemory() and rhs.isMemory()) {
-                    dst_mcv = try self.moveToNewRegister(op_lhs);
+                    dst_mcv = try self.copyToNewRegister(op_lhs);
                     src_mcv = rhs;
                 } else {
                     dst_mcv = lhs;
@@ -527,7 +538,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 // Both operands cannot be memory.
                 src_inst = op_lhs;
                 if (lhs.isMemory() and rhs.isMemory()) {
-                    dst_mcv = try self.moveToNewRegister(op_rhs);
+                    dst_mcv = try self.copyToNewRegister(op_rhs);
                     src_mcv = lhs;
                 } else {
                     dst_mcv = rhs;
@@ -535,11 +546,11 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 }
             } else {
                 if (lhs.isMemory()) {
-                    dst_mcv = try self.moveToNewRegister(op_lhs);
+                    dst_mcv = try self.copyToNewRegister(op_lhs);
                     src_mcv = rhs;
                     src_inst = op_rhs;
                 } else {
-                    dst_mcv = try self.moveToNewRegister(op_rhs);
+                    dst_mcv = try self.copyToNewRegister(op_rhs);
                     src_mcv = lhs;
                     src_inst = op_lhs;
                 }
@@ -552,7 +563,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             switch (src_mcv) {
                 .immediate => |imm| {
                     if (imm > std.math.maxInt(u31)) {
-                        src_mcv = try self.moveToNewRegister(src_inst);
+                        src_mcv = try self.copyToNewRegister(src_inst);
                     }
                 },
                 else => {},
@@ -614,9 +625,26 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
         }
 
         fn genArg(self: *Self, inst: *ir.Inst.Arg) !MCValue {
-            const i = self.arg_index;
+            // Returns the MCValue of the next argument; a register argument is also recorded in the current branch.
+            if (FreeRegInt == u0) { // `callee_preserved_regs` is empty for this arch, so there is nothing to track
+                return self.fail(inst.base.src, "TODO implement Register enum for {}", .{self.target.cpu.arch});
+            }
+            // Consume this argument's slot before any early return, so an unused argument does not shift the slots after it.
+            const result = self.args[self.arg_index];
             self.arg_index += 1;
-            return self.args[i];
+            if (inst.base.isUnused())
+                return MCValue.dead;
+
+            const branch = &self.branch_stack.items[self.branch_stack.items.len - 1];
+            try branch.registers.ensureCapacity(self.gpa, branch.registers.items().len + 1); // reserve so the put below cannot fail
+            switch (result) {
+                .register => |reg| {
+                    branch.registers.putAssumeCapacityNoClobber(reg, .{ .inst = &inst.base });
+                    branch.markRegUsed(reg);
+                },
+                else => {}, // other MCValue kinds need no register bookkeeping here
+            }
+            return result;
         }
 
         fn genBreakpoint(self: *Self, src: usize) !MCValue {
@@ -737,7 +765,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                     // Either one, but not both, can be a memory operand.
                     // Source operand can be an immediate, 8 bits or 32 bits.
                     const dst_mcv = if (lhs.isImmediate() or (lhs.isMemory() and rhs.isMemory()))
-                        try self.moveToNewRegister(inst.args.lhs)
+                        try self.copyToNewRegister(inst.args.lhs)
                     else
                         lhs;
                     // This instruction supports only signed 32-bit immediates at most.
@@ -949,7 +977,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             }
         }
 
-        fn genSetReg(self: *Self, src: usize, reg: Reg, mcv: MCValue) error{ CodegenFail, OutOfMemory }!void {
+        fn genSetReg(self: *Self, src: usize, reg: Register, mcv: MCValue) error{ CodegenFail, OutOfMemory }!void {
             switch (arch) {
                 .x86_64 => switch (mcv) {
                     .dead => unreachable,
@@ -1171,9 +1199,22 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             }
         }
 
-        fn moveToNewRegister(self: *Self, inst: *ir.Inst) !MCValue {
+        /// Copies the value of `inst` into a free register and returns that register. Does not "move" the instruction: its `inst_table` entry is left as is.
+        fn copyToNewRegister(self: *Self, inst: *ir.Inst) !MCValue {
             const branch = &self.branch_stack.items[self.branch_stack.items.len - 1];
-            return self.fail(inst.src, "TODO implement moveToNewRegister", .{});
+            try branch.registers.ensureCapacity(self.gpa, branch.registers.items().len + 1); // reserve so the `putAssumeCapacityNoClobber` below cannot fail
+            try branch.inst_table.ensureCapacity(self.gpa, branch.inst_table.items().len + 1); // NOTE(review): nothing is inserted into `inst_table` in this function — confirm this reservation is needed
+
+            const free_index = @ctz(FreeRegInt, branch.free_registers); // index of the lowest set (free) bit
+            if (free_index >= callee_preserved_regs.len) // no bit was set: every tracked register is in use
+                return self.fail(inst.src, "TODO implement spilling register to stack", .{});
+            branch.free_registers &= ~(@as(FreeRegInt, 1) << free_index); // clear the bit: the register is now in use
+            const reg = callee_preserved_regs[free_index];
+            branch.registers.putAssumeCapacityNoClobber(reg, .{ .inst = inst });
+            const old_mcv = branch.inst_table.get(inst).?; // where the value lives now; `.?` assumes `inst` already has an entry
+            const new_mcv: MCValue = .{ .register = reg };
+            try self.genSetReg(inst.src, reg, old_mcv); // emit the copy from the old location into `reg`
+            return new_mcv;
         }
 
         /// If the MCValue is an immediate, and it does not fit within this type,
@@ -1194,7 +1235,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                         },
                     });
                     if (imm >= std.math.maxInt(U)) {
-                        return self.moveToNewRegister(inst);
+                        return self.copyToNewRegister(inst);
                     }
                 },
                 else => {},
@@ -1249,15 +1290,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                             var next_int_reg: usize = 0;
                             var next_stack_offset: u32 = 0;
 
-                            const integer_registers = [_]Reg{ .rdi, .rsi, .rdx, .rcx, .r8, .r9 };
                             for (param_types) |ty, i| {
                                 switch (ty.zigTypeTag()) {
                                     .Bool, .Int => {
-                                        if (next_int_reg >= integer_registers.len) {
+                                        if (next_int_reg >= c_abi_int_param_regs.len) {
                                             results[i] = .{ .stack_offset = next_stack_offset };
                                             next_stack_offset += @intCast(u32, ty.abiSize(self.target.*));
                                         } else {
-                                            results[i] = .{ .register = integer_registers[next_int_reg] };
+                                            results[i] = .{ .register = c_abi_int_param_regs[next_int_reg] };
                                             next_int_reg += 1;
                                         }
                                     },
@@ -1280,14 +1320,26 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             return error.CodegenFail;
         }
 
-        const Reg = switch (arch) {
-            .i386 => x86.Register,
-            .x86_64 => x86_64.Register,
-            else => enum { dummy },
+        usingnamespace switch (arch) { // brings the arch-specific declarations (`Register`, `callee_preserved_regs`, ...) into this scope
+            .i386 => @import("codegen/x86.zig"),
+            .x86_64 => @import("codegen/x86_64.zig"),
+            else => struct { // placeholder for architectures without a register definition
+                pub const Register = enum {
+                    dummy,
+
+                    pub fn allocIndex(self: Register) ?u4 {
+                        return null;
+                    }
+                };
+                pub const callee_preserved_regs = [_]Register{}; // empty, which makes `FreeRegInt` a `u0`
+            },
+        };
 
-        fn parseRegName(name: []const u8) ?Reg {
-            return std.meta.stringToEnum(Reg, name);
+        /// An integer with one bit per entry of `callee_preserved_regs`; bit `i` is set while that register is free.
+        const FreeRegInt = @Type(.{ .Int = .{ .is_signed = false, .bits = callee_preserved_regs.len } });
+
+        fn parseRegName(name: []const u8) ?Register { // looks `name` up among the `Register` enum fields
+            return std.meta.stringToEnum(Register, name);
         }
     };
 }
test/stage2/compare_output.zig
@@ -169,9 +169,8 @@ pub fn addCases(ctx: *TestContext) !void {
         ,
             "",
         );
-    }
-    {
-        var case = ctx.exe("assert function", linux_x64);
+
+        // Tests the assert() function.
         case.addCompareOutput(
             \\export fn _start() noreturn {
             \\    add(3, 4);
@@ -199,15 +198,21 @@ pub fn addCases(ctx: *TestContext) !void {
         ,
             "",
         );
+
+        // Tests copying a register. To compute `c = a + b`, codegen has to
+        // preserve both a and b, because they are both used later.
         case.addCompareOutput(
             \\export fn _start() noreturn {
-            \\    add(100, 200);
+            \\    add(3, 4);
             \\
             \\    exit();
             \\}
             \\
             \\fn add(a: u32, b: u32) void {
-            \\    assert(a + b == 300);
+            \\    const c = a + b; // 7
+            \\    const d = a + c; // 10
+            \\    const e = d + b; // 14
+            \\    assert(e == 14);
             \\}
             \\
             \\pub fn assert(ok: bool) void {