Commit 875a16030c

Jakub Konka <kubkon@jakubkonka.com>
2022-05-12 17:40:54
x64: extend Emit to allow for AVX registers
1 parent 70d809e
Changed files (2)
src
arch
src/arch/x86_64/bits.zig
@@ -151,6 +151,11 @@ pub const AvxRegister = enum(u6) {
         };
     }
 
+    /// Returns whether the register is *extended*.
+    pub fn isExtended(self: Register) bool {
+        return @enumToInt(self) & 0x08 != 0;
+    }
+
     /// This returns the 4-bit register ID.
     pub fn id(self: AvxRegister) u4 {
         return @truncate(u4, @enumToInt(self));
src/arch/x86_64/Emit.zig
@@ -25,7 +25,8 @@ const MCValue = @import("CodeGen.zig").MCValue;
 const Mir = @import("Mir.zig");
 const Module = @import("../../Module.zig");
 const Instruction = bits.Instruction;
-const Register = bits.Register;
+const GpRegister = bits.Register;
+const AvxRegister = bits.Register;
 const Type = @import("../../type.zig").Type;
 
 mir: Mir,
@@ -248,7 +249,7 @@ fn mirPushPop(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     switch (ops.flags) {
         0b00 => {
             // PUSH/POP reg
-            return lowerToOEnc(tag, ops.reg1, emit.code);
+            return lowerToOEnc(tag, .{ .register = ops.reg1 }, emit.code);
         },
         0b01 => {
             // PUSH/POP r/m64
@@ -271,6 +272,7 @@ fn mirPushPop(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
         0b11 => unreachable,
     }
 }
+
 fn mirPushPopRegsFromCalleePreservedRegs(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const payload = emit.mir.instructions.items(.data)[inst].payload;
@@ -283,9 +285,9 @@ fn mirPushPopRegsFromCalleePreservedRegs(emit: *Emit, tag: Tag, inst: Mir.Inst.I
             try lowerToMrEnc(.mov, RegisterOrMemory.mem(.qword_ptr, .{
                 .disp = @bitCast(u32, -@intCast(i32, disp)),
                 .base = ops.reg1,
-            }), reg.to64(), emit.code);
+            }), .{ .register = reg.to64() }, emit.code);
         } else {
-            try lowerToRmEnc(.mov, reg.to64(), RegisterOrMemory.mem(.qword_ptr, .{
+            try lowerToRmEnc(.mov, .{ .register = reg.to64() }, RegisterOrMemory.mem(.qword_ptr, .{
                 .disp = @bitCast(u32, -@intCast(i32, disp)),
                 .base = ops.reg1,
             }), emit.code);
@@ -319,7 +321,7 @@ fn mirJmpCall(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
                 return lowerToMEnc(tag, RegisterOrMemory.mem(ptr_size, .{ .disp = imm }), emit.code);
             }
             // JMP/CALL reg
-            return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
+            return lowerToMEnc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), emit.code);
         },
         0b10 => {
             // JMP/CALL r/m64
@@ -392,13 +394,13 @@ fn mirCondSetByte(emit: *Emit, mir_tag: Mir.Inst.Tag, inst: Mir.Inst.Index) Inne
         },
         else => unreachable,
     };
-    return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1.to8()), emit.code);
+    return lowerToMEnc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1.to8() }), emit.code);
 }
 
 fn mirCondMov(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     if (ops.flags == 0b00) {
-        return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.reg(ops.reg2), emit.code);
+        return lowerToRmEnc(tag, .{ .register = ops.reg1 }, RegisterOrMemory.reg(.{ .register = ops.reg2 }), emit.code);
     }
     const imm = emit.mir.instructions.items(.data)[inst].imm;
     const ptr_size: Memory.PtrSize = switch (ops.flags) {
@@ -407,7 +409,7 @@ fn mirCondMov(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
         0b10 => .dword_ptr,
         0b11 => .qword_ptr,
     };
-    return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.mem(ptr_size, .{
+    return lowerToRmEnc(tag, .{ .register = ops.reg1 }, RegisterOrMemory.mem(ptr_size, .{
         .disp = imm,
         .base = ops.reg2,
     }), emit.code);
@@ -428,10 +430,15 @@ fn mirTest(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
                     // I
                     return lowerToIEnc(.@"test", imm, emit.code);
                 }
-                return lowerToMiEnc(.@"test", RegisterOrMemory.reg(ops.reg1), imm, emit.code);
+                return lowerToMiEnc(.@"test", RegisterOrMemory.reg(.{ .register = ops.reg1 }), imm, emit.code);
             }
             // TEST r/m64, r64
-            return lowerToMrEnc(.@"test", RegisterOrMemory.reg(ops.reg1), ops.reg2, emit.code);
+            return lowerToMrEnc(
+                .@"test",
+                RegisterOrMemory.reg(.{ .register = ops.reg1 }),
+                .{ .register = ops.reg2 },
+                emit.code,
+            );
         },
         else => return emit.fail("TODO more TEST alternatives", .{}),
     }
@@ -471,18 +478,18 @@ fn mirArith(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
                 // mov reg1, imm32
                 // MI
                 const imm = emit.mir.instructions.items(.data)[inst].imm;
-                return lowerToMiEnc(tag, RegisterOrMemory.reg(ops.reg1), imm, emit.code);
+                return lowerToMiEnc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), imm, emit.code);
             }
             // mov reg1, reg2
             // RM
-            return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.reg(ops.reg2), emit.code);
+            return lowerToRmEnc(tag, .{ .register = ops.reg1 }, RegisterOrMemory.reg(.{ .register = ops.reg2 }), emit.code);
         },
         0b01 => {
             // mov reg1, [reg2 + imm32]
             // RM
             const imm = emit.mir.instructions.items(.data)[inst].imm;
-            const src_reg: ?Register = if (ops.reg2 == .none) null else ops.reg2;
-            return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
+            const src_reg: ?GpRegister = if (ops.reg2 == .none) null else ops.reg2;
+            return lowerToRmEnc(tag, .{ .register = ops.reg1 }, RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
                 .disp = imm,
                 .base = src_reg,
             }), emit.code);
@@ -497,7 +504,7 @@ fn mirArith(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
             return lowerToMrEnc(tag, RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg2.size()), .{
                 .disp = imm,
                 .base = ops.reg1,
-            }), ops.reg2, emit.code);
+            }), .{ .register = ops.reg2 }, emit.code);
         },
         0b11 => {
             return emit.fail("TODO unused variant: mov reg1, reg2, 0b11", .{});
@@ -523,11 +530,16 @@ fn mirArithMemImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
 }
 
 inline fn setRexWRegister(reg: Register) bool {
-    if (reg.size() == 64) return true;
-    return switch (reg) {
-        .ah, .bh, .ch, .dh => true,
-        else => false,
-    };
+    switch (reg) {
+        .avx_register => return false,
+        .register => |r| {
+            if (r.size() == 64) return true;
+            return switch (r) {
+                .ah, .bh, .ch, .dh => true,
+                else => false,
+            };
+        },
+    }
 }
 
 inline fn immOpSize(u_imm: u32) u8 {
@@ -550,7 +562,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
         .scale = scale,
         .index = .rcx,
     };
-    return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
+    return lowerToRmEnc(tag, .{ .register = ops.reg1 }, RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
         .disp = imm,
         .base = ops.reg2,
         .scale_index = scale_index,
@@ -578,7 +590,7 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
         .disp = imm,
         .base = ops.reg1,
         .scale_index = scale_index,
-    }), ops.reg2, emit.code);
+    }), .{ .register = ops.reg2 }, emit.code);
 }
 
 fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
@@ -629,22 +641,27 @@ fn mirMovSignExtend(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     switch (ops.flags) {
         0b00 => {
             const tag: Tag = if (ops.reg2.size() == 32) .movsxd else .movsx;
-            return lowerToRmEnc(tag, ops.reg1, RegisterOrMemory.reg(ops.reg2), emit.code);
+            return lowerToRmEnc(
+                tag,
+                .{ .register = ops.reg1 },
+                RegisterOrMemory.reg(.{ .register = ops.reg2 }),
+                emit.code,
+            );
         },
         0b01 => {
-            return lowerToRmEnc(.movsx, ops.reg1, RegisterOrMemory.mem(.byte_ptr, .{
+            return lowerToRmEnc(.movsx, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.byte_ptr, .{
                 .disp = imm,
                 .base = ops.reg2,
             }), emit.code);
         },
         0b10 => {
-            return lowerToRmEnc(.movsx, ops.reg1, RegisterOrMemory.mem(.word_ptr, .{
+            return lowerToRmEnc(.movsx, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.word_ptr, .{
                 .disp = imm,
                 .base = ops.reg2,
             }), emit.code);
         },
         0b11 => {
-            return lowerToRmEnc(.movsxd, ops.reg1, RegisterOrMemory.mem(.dword_ptr, .{
+            return lowerToRmEnc(.movsxd, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.dword_ptr, .{
                 .disp = imm,
                 .base = ops.reg2,
             }), emit.code);
@@ -659,16 +676,21 @@ fn mirMovZeroExtend(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     const imm = if (ops.flags != 0b00) emit.mir.instructions.items(.data)[inst].imm else undefined;
     switch (ops.flags) {
         0b00 => {
-            return lowerToRmEnc(.movzx, ops.reg1, RegisterOrMemory.reg(ops.reg2), emit.code);
+            return lowerToRmEnc(
+                .movzx,
+                .{ .register = ops.reg1 },
+                RegisterOrMemory.reg(.{ .register = ops.reg2 }),
+                emit.code,
+            );
         },
         0b01 => {
-            return lowerToRmEnc(.movzx, ops.reg1, RegisterOrMemory.mem(.byte_ptr, .{
+            return lowerToRmEnc(.movzx, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.byte_ptr, .{
                 .disp = imm,
                 .base = ops.reg2,
             }), emit.code);
         },
         0b10 => {
-            return lowerToRmEnc(.movzx, ops.reg1, RegisterOrMemory.mem(.word_ptr, .{
+            return lowerToRmEnc(.movzx, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.word_ptr, .{
                 .disp = imm,
                 .base = ops.reg2,
             }), emit.code);
@@ -691,16 +713,16 @@ fn mirMovabs(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     if (ops.flags == 0b00) {
         // movabs reg, imm64
         // OI
-        return lowerToOiEnc(.mov, ops.reg1, imm, emit.code);
+        return lowerToOiEnc(.mov, .{ .register = ops.reg1 }, imm, emit.code);
     }
     if (ops.reg1 == .none) {
         // movabs moffs64, rax
         // TD
-        return lowerToTdEnc(.mov, imm, ops.reg2, emit.code);
+        return lowerToTdEnc(.mov, imm, .{ .register = ops.reg2 }, emit.code);
     }
     // movabs rax, moffs64
     // FD
-    return lowerToFdEnc(.mov, ops.reg1, imm, emit.code);
+    return lowerToFdEnc(.mov, .{ .register = ops.reg1 }, imm, emit.code);
 }
 
 fn mirFisttp(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
@@ -751,18 +773,18 @@ fn mirShift(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
         0b00 => {
             // sal reg1, 1
             // M1
-            return lowerToM1Enc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
+            return lowerToM1Enc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), emit.code);
         },
         0b01 => {
             // sal reg1, .cl
             // MC
-            return lowerToMcEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
+            return lowerToMcEnc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), emit.code);
         },
         0b10 => {
             // sal reg1, imm8
             // MI
             const imm = @truncate(u8, emit.mir.instructions.items(.data)[inst].imm);
-            return lowerToMiImm8Enc(tag, RegisterOrMemory.reg(ops.reg1), imm, emit.code);
+            return lowerToMiImm8Enc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), imm, emit.code);
         },
         0b11 => {
             return emit.fail("TODO unused variant: SHIFT reg1, 0b11", .{});
@@ -774,7 +796,7 @@ fn mirMulDiv(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     if (ops.reg1 != .none) {
         assert(ops.reg2 == .none);
-        return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
+        return lowerToMEnc(tag, RegisterOrMemory.reg(.{ .register = ops.reg1 }), emit.code);
     }
     assert(ops.reg1 == .none);
     assert(ops.reg2 != .none);
@@ -797,24 +819,35 @@ fn mirIMulComplex(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     switch (ops.flags) {
         0b00 => {
-            return lowerToRmEnc(.imul, ops.reg1, RegisterOrMemory.reg(ops.reg2), emit.code);
+            return lowerToRmEnc(
+                .imul,
+                .{ .register = ops.reg1 },
+                RegisterOrMemory.reg(.{ .register = ops.reg2 }),
+                emit.code,
+            );
         },
         0b01 => {
             const imm = emit.mir.instructions.items(.data)[inst].imm;
-            const src_reg: ?Register = if (ops.reg2 == .none) null else ops.reg2;
-            return lowerToRmEnc(.imul, ops.reg1, RegisterOrMemory.mem(.qword_ptr, .{
+            const src_reg: ?GpRegister = if (ops.reg2 == .none) null else ops.reg2;
+            return lowerToRmEnc(.imul, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.qword_ptr, .{
                 .disp = imm,
                 .base = src_reg,
             }), emit.code);
         },
         0b10 => {
             const imm = emit.mir.instructions.items(.data)[inst].imm;
-            return lowerToRmiEnc(.imul, ops.reg1, RegisterOrMemory.reg(ops.reg2), imm, emit.code);
+            return lowerToRmiEnc(
+                .imul,
+                .{ .register = ops.reg1 },
+                RegisterOrMemory.reg(.{ .register = ops.reg2 }),
+                imm,
+                emit.code,
+            );
         },
         0b11 => {
             const payload = emit.mir.instructions.items(.data)[inst].payload;
             const imm_pair = emit.mir.extraData(Mir.ImmPair, payload).data;
-            return lowerToRmiEnc(.imul, ops.reg1, RegisterOrMemory.mem(.qword_ptr, .{
+            return lowerToRmiEnc(.imul, .{ .register = ops.reg1 }, RegisterOrMemory.mem(.qword_ptr, .{
                 .disp = imm_pair.dest_off,
                 .base = ops.reg2,
             }), imm_pair.operand, emit.code);
@@ -842,10 +875,10 @@ fn mirLea(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
             // lea reg1, [reg2 + imm32]
             // RM
             const imm = emit.mir.instructions.items(.data)[inst].imm;
-            const src_reg: ?Register = if (ops.reg2 == .none) null else ops.reg2;
+            const src_reg: ?GpRegister = if (ops.reg2 == .none) null else ops.reg2;
             return lowerToRmEnc(
                 .lea,
-                ops.reg1,
+                .{ .register = ops.reg1 },
                 RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
                     .disp = imm,
                     .base = src_reg,
@@ -859,7 +892,7 @@ fn mirLea(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
             const start_offset = emit.code.items.len;
             try lowerToRmEnc(
                 .lea,
-                ops.reg1,
+                .{ .register = ops.reg1 },
                 RegisterOrMemory.rip(Memory.PtrSize.fromBits(ops.reg1.size()), 0),
                 emit.code,
             );
@@ -873,14 +906,14 @@ fn mirLea(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
         0b10 => {
             // lea reg, [rbp + rcx + imm32]
             const imm = emit.mir.instructions.items(.data)[inst].imm;
-            const src_reg: ?Register = if (ops.reg2 == .none) null else ops.reg2;
+            const src_reg: ?GpRegister = if (ops.reg2 == .none) null else ops.reg2;
             const scale_index = ScaleIndex{
                 .scale = 0,
                 .index = .rcx,
             };
             return lowerToRmEnc(
                 .lea,
-                ops.reg1,
+                .{ .register = ops.reg1 },
                 RegisterOrMemory.mem(Memory.PtrSize.fromBits(ops.reg1.size()), .{
                     .disp = imm,
                     .base = src_reg,
@@ -903,7 +936,7 @@ fn mirLeaPie(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     // RM
     try lowerToRmEnc(
         .lea,
-        ops.reg1,
+        .{ .register = ops.reg1 },
         RegisterOrMemory.rip(Memory.PtrSize.fromBits(ops.reg1.size()), 0),
         emit.code,
     );
@@ -1489,11 +1522,11 @@ inline fn getModRmExt(tag: Tag) ?u3 {
 
 const ScaleIndex = struct {
     scale: u2,
-    index: Register,
+    index: GpRegister,
 };
 
 const Memory = struct {
-    base: ?Register,
+    base: ?GpRegister,
     rip: bool = false,
     disp: u32,
     ptr_size: PtrSize,
@@ -1595,6 +1628,40 @@ fn encodeImm(encoder: Encoder, imm: u32, size: u64) void {
     }
 }
 
+const Register = union(enum) {
+    register: GpRegister,
+    avx_register: AvxRegister,
+
+    fn reg(register: GpRegister) Register {
+        return .{ .register = register };
+    }
+
+    fn avxReg(register: AvxRegister) Register {
+        return .{ .avx_register = register };
+    }
+
+    fn lowId(register: Register) u3 {
+        return switch (register) {
+            .register => |r| r.lowId(),
+            .avx_register => |r| r.lowId(),
+        };
+    }
+
+    fn size(register: Register) u64 {
+        return switch (register) {
+            .register => |r| r.size(),
+            .avx_register => |r| r.size(),
+        };
+    }
+
+    fn isExtended(register: Register) bool {
+        return switch (register) {
+            .register => |r| r.isExtended(),
+            .avx_register => |r| r.isExtended(),
+        };
+    }
+};
+
 const RegisterOrMemory = union(enum) {
     register: Register,
     memory: Memory,
@@ -1605,7 +1672,7 @@ const RegisterOrMemory = union(enum) {
 
     fn mem(ptr_size: Memory.PtrSize, args: struct {
         disp: u32,
-        base: ?Register = null,
+        base: ?GpRegister = null,
         scale_index: ?ScaleIndex = null,
     }) RegisterOrMemory {
         return .{