Commit 9078cb0197

Jakub Konka <kubkon@jakubkonka.com>
2021-12-23 01:22:07
stage2: add lowering of M encoding
Examples include jmp / call near with memory or register operand like `jmp [rax]`, or even RIP-relative `call [rip + 0x10]`.
1 parent 1167e24
Changed files (1)
src
arch
x86_64
src/arch/x86_64/Emit.zig
@@ -304,25 +304,13 @@ fn mirJmpCall(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
         });
         return;
     }
-    const modrm_ext: u3 = switch (tag) {
-        .jmp_near => 0x4,
-        .call_near => 0x2,
-        else => unreachable,
-    };
     if (ops.reg1 == .none) {
         // JMP/CALL [imm]
         const imm = emit.mir.instructions.items(.data)[inst].imm;
-        const encoder = try Encoder.init(emit.code, 7);
-        encoder.opcode_1byte(0xff);
-        encoder.modRm_SIBDisp0(modrm_ext);
-        encoder.sib_disp32();
-        encoder.imm32(imm);
-        return;
+        return lowerToMEnc(tag, RegisterOrMemory.mem(null, imm), emit.code);
     }
     // JMP/CALL reg
-    const encoder = try Encoder.init(emit.code, 2);
-    encoder.opcode_1byte(0xff);
-    encoder.modRm_direct(modrm_ext, ops.reg1.lowId());
+    return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
 }
 
 const CondType = enum {
@@ -628,7 +616,7 @@ inline fn getOpCode(tag: Tag, enc: Encoding) ?u8 {
     }
 }
 
-inline fn getModRmExt(tag: Tag) u3 {
+inline fn getModRmExt(tag: Tag) ?u3 {
     return switch (tag) {
         .adc => 0x2,
         .add => 0x0,
@@ -639,8 +627,9 @@ inline fn getModRmExt(tag: Tag) u3 {
         .sbb => 0x3,
         .cmp => 0x7,
         .mov => 0x0,
+        .jmp_near => 0x4,
         .call_near => 0x2,
-        else => unreachable,
+        else => null,
     };
 }
 
@@ -692,6 +681,67 @@ fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
     encoder.imm32(imm);
 }
 
+/// Lowers an instruction that takes a single ModRM r/m operand (M encoding),
+/// such as `jmp r12`, `jmp [r12 + 0x10]` or `jmp [rip + 0x10]`, appending the
+/// encoded bytes to `code`.
+///
+/// The opcode and the ModRM opcode extension are looked up from `tag`; both
+/// lookups are unwrapped with `.?`, so `tag` must have an M-encoded form.
+///
+/// Returns `error.EmitFail` when the register operand, or the base register of
+/// a memory operand, is not 64 bits wide. Errors from `Encoder.init` are
+/// propagated unchanged.
+fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) InnerError!void {
+    const opc = getOpCode(tag, .m).?;
+    const modrm_ext = getModRmExt(tag).?;
+    switch (reg_or_mem) {
+        .register => |reg| {
+            if (reg.size() != 64) return error.EmitFail;
+            // Worst case: REX + opcode + ModRM.
+            const encoder = try Encoder.init(code, 3);
+            encoder.rex(.{
+                .w = false,
+                .b = reg.isExtended(),
+            });
+            encoder.opcode_1byte(opc);
+            encoder.modRm_direct(modrm_ext, reg.lowId());
+        },
+        .memory => |mem_op| {
+            // Worst case: REX + opcode + ModRM + SIB + disp32.
+            const encoder = try Encoder.init(code, 8);
+            if (mem_op.reg) |reg| {
+                if (reg.size() != 64) return error.EmitFail;
+                const base = reg.lowId();
+                encoder.rex(.{
+                    .w = false,
+                    .b = reg.isExtended(),
+                });
+                encoder.opcode_1byte(opc);
+                if (base == 4) {
+                    // rm=100 does not name a register in an indirect ModRM; it
+                    // announces a SIB byte, so rsp/r12 bases always go via SIB.
+                    if (mem_op.disp == 0) {
+                        encoder.modRm_SIBDisp0(modrm_ext);
+                        encoder.sib_base(base);
+                    } else if (immOpSize(mem_op.disp) == 8) {
+                        encoder.modRm_SIBDisp8(modrm_ext);
+                        encoder.sib_baseDisp8(base);
+                        encoder.disp8(@intCast(i8, mem_op.disp));
+                    } else {
+                        encoder.modRm_SIBDisp32(modrm_ext);
+                        encoder.sib_baseDisp32(base);
+                        encoder.disp32(mem_op.disp);
+                    }
+                } else if (mem_op.disp == 0 and base != 5) {
+                    // mod=00 with rm=101 means `[rip + disp32]`, not `[rbp]` or
+                    // `[r13]`, so those bases must not take the displacement-free
+                    // form; they fall through and get an explicit zero displacement.
+                    encoder.modRm_indirectDisp0(modrm_ext, base);
+                } else if (immOpSize(mem_op.disp) == 8) {
+                    encoder.modRm_indirectDisp8(modrm_ext, base);
+                    encoder.disp8(@intCast(i8, mem_op.disp));
+                } else {
+                    encoder.modRm_indirectDisp32(modrm_ext, base);
+                    encoder.disp32(mem_op.disp);
+                }
+            } else {
+                // No base register: either RIP-relative or an absolute disp32
+                // addressed through a base-less SIB byte. No REX prefix is emitted.
+                encoder.opcode_1byte(opc);
+                if (mem_op.rip) {
+                    encoder.modRm_RIPDisp32(modrm_ext);
+                } else {
+                    encoder.modRm_SIBDisp0(modrm_ext);
+                    encoder.sib_disp32();
+                }
+                encoder.disp32(mem_op.disp);
+            }
+        },
+    }
+}
+
 fn lowerToTdEnc(tag: Tag, moffs: i64, reg: Register, code: *std.ArrayList(u8)) InnerError!void {
     return lowerToTdFdEnc(tag, reg, moffs, code, true);
 }
@@ -772,7 +822,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn
 
 fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
     var opc = getOpCode(tag, .mi).?;
-    const modrm_ext = getModRmExt(tag);
+    const modrm_ext = getModRmExt(tag).?;
     switch (reg_or_mem) {
         .register => |dst_reg| {
             if (dst_reg.size() == 8) {
@@ -819,16 +869,25 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
                     .b = dst_reg.isExtended(),
                 });
                 encoder.opcode_1byte(opc);
-                if (dst_mem.disp == 0) {
-                    encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId());
-                } else if (immOpSize(dst_mem.disp) == 8) {
-                    encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId());
-                    encoder.disp8(@intCast(i8, dst_mem.disp));
-                } else {
-                    if (dst_reg.lowId() == 4) {
+                if (dst_reg.lowId() == 4) {
+                    if (dst_mem.disp == 0) {
+                        encoder.modRm_SIBDisp0(modrm_ext);
+                        encoder.sib_base(dst_reg.lowId());
+                    } else if (immOpSize(dst_mem.disp) == 8) {
+                        encoder.modRm_SIBDisp8(modrm_ext);
+                        encoder.sib_baseDisp8(dst_reg.lowId());
+                        encoder.disp8(@intCast(i8, dst_mem.disp));
+                    } else {
                         encoder.modRm_SIBDisp32(modrm_ext);
                         encoder.sib_baseDisp32(dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
+                    }
+                } else {
+                    if (dst_mem.disp == 0) {
+                        encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId());
+                    } else if (immOpSize(dst_mem.disp) == 8) {
+                        encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId());
+                        encoder.disp8(@intCast(i8, dst_mem.disp));
                     } else {
                         encoder.modRm_indirectDisp32(modrm_ext, dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
@@ -886,16 +945,25 @@ fn lowerToRmEnc(
                     .b = src_reg.isExtended(),
                 });
                 encoder.opcode_1byte(opc);
-                if (src_mem.disp == 0) {
-                    encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId());
-                } else if (immOpSize(src_mem.disp) == 8) {
-                    encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId());
-                    encoder.disp8(@intCast(i8, src_mem.disp));
-                } else {
-                    if (src_reg.lowId() == 4) {
+                if (src_reg.lowId() == 4) {
+                    if (src_mem.disp == 0) {
+                        encoder.modRm_SIBDisp0(reg.lowId());
+                        encoder.sib_base(src_reg.lowId());
+                    } else if (immOpSize(src_mem.disp) == 8) {
+                        encoder.modRm_SIBDisp8(reg.lowId());
+                        encoder.sib_baseDisp8(src_reg.lowId());
+                        encoder.disp8(@intCast(i8, src_mem.disp));
+                    } else {
                         encoder.modRm_SIBDisp32(reg.lowId());
                         encoder.sib_baseDisp32(src_reg.lowId());
                         encoder.disp32(src_mem.disp);
+                    }
+                } else {
+                    if (src_mem.disp == 0) {
+                        encoder.modRm_indirectDisp0(reg.lowId(), src_reg.lowId());
+                    } else if (immOpSize(src_mem.disp) == 8) {
+                        encoder.modRm_indirectDisp8(reg.lowId(), src_reg.lowId());
+                        encoder.disp8(@intCast(i8, src_mem.disp));
                     } else {
                         encoder.modRm_indirectDisp32(reg.lowId(), src_reg.lowId());
                         encoder.disp32(src_mem.disp);
@@ -960,16 +1028,25 @@ fn lowerToMrEnc(
                     .b = dst_reg.isExtended(),
                 });
                 encoder.opcode_1byte(opc);
-                if (dst_mem.disp == 0) {
-                    encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId());
-                } else if (immOpSize(dst_mem.disp) == 8) {
-                    encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId());
-                    encoder.disp8(@intCast(i8, dst_mem.disp));
-                } else {
-                    if (dst_reg.lowId() == 4) {
+                if (dst_reg.lowId() == 4) {
+                    if (dst_mem.disp == 0) {
+                        encoder.modRm_SIBDisp0(reg.lowId());
+                        encoder.sib_base(dst_reg.lowId());
+                    } else if (immOpSize(dst_mem.disp) == 8) {
+                        encoder.modRm_SIBDisp8(reg.lowId());
+                        encoder.sib_baseDisp8(dst_reg.lowId());
+                        encoder.disp8(@intCast(i8, dst_mem.disp));
+                    } else {
                         encoder.modRm_SIBDisp32(reg.lowId());
                         encoder.sib_baseDisp32(dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
+                    }
+                } else {
+                    if (dst_mem.disp == 0) {
+                        encoder.modRm_indirectDisp0(reg.lowId(), dst_reg.lowId());
+                    } else if (immOpSize(dst_mem.disp) == 8) {
+                        encoder.modRm_indirectDisp8(reg.lowId(), dst_reg.lowId());
+                        encoder.disp8(@intCast(i8, dst_mem.disp));
                     } else {
                         encoder.modRm_indirectDisp32(reg.lowId(), dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
@@ -1099,7 +1176,7 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     if (ops.reg2 == .none) {
         // OP [reg1 + scale*rax + 0], imm32
         var opc = getOpCode(tag, .mi).?;
-        const modrm_ext = getModRmExt(tag);
+        const modrm_ext = getModRmExt(tag).?;
         if (ops.reg1.size() == 8) {
             opc -= 1;
         }
@@ -1153,7 +1230,7 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     if (ops.reg1.size() == 8) {
         opc -= 1;
     }
-    const modrm_ext = getModRmExt(tag);
+    const modrm_ext = getModRmExt(tag).?;
     const encoder = try Encoder.init(emit.code, 2);
     encoder.rex(.{
         .w = ops.reg1.size() == 64,
@@ -1641,3 +1718,24 @@ test "lower FD/TD encoding" {
     try lowerToFdEnc(.mov, .al, 0x10, code.buffer());
     try expectEqualHexStrings("\xa0\x10", code.emitted(), "mov al, ds:0x10");
 }
+
+test "lower M encoding" { // Pins the exact bytes lowerToMEnc emits for `jmp` with register, base+disp, RIP-relative and absolute operands.
+    var code = TestEmitCode.init();
+    defer code.deinit();
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.reg(.r12), code.buffer()); // REX.B, FF, ModRM 0xE4 (mod=11, /4, rm=100)
+    try expectEqualHexStrings("\x41\xFF\xE4", code.emitted(), "jmp r12");
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0), code.buffer()); // ModRM 0x24 (mod=00, rm=100 -> SIB), SIB 0x24 (no index, base=100)
+    try expectEqualHexStrings("\x41\xFF\x24\x24", code.emitted(), "jmp qword ptr [r12]");
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x10), code.buffer()); // ModRM 0x64 (mod=01): SIB followed by a one-byte displacement
+    try expectEqualHexStrings("\x41\xFF\x64\x24\x10", code.emitted(), "jmp qword ptr [r12 + 0x10]");
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(.r12, 0x1000), code.buffer()); // ModRM 0xA4 (mod=10): SIB followed by a little-endian disp32
+    try expectEqualHexStrings(
+        "\x41\xFF\xA4\x24\x00\x10\x00\x00",
+        code.emitted(),
+        "jmp qword ptr [r12 + 0x1000]",
+    );
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.rip(0x10), code.buffer()); // no REX; ModRM 0x25 (mod=00, rm=101) selects RIP-relative disp32
+    try expectEqualHexStrings("\xFF\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [rip + 0x10]");
+    try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(null, 0x10), code.buffer()); // no REX; ModRM 0x24 + SIB 0x25 (no index, no base) selects absolute disp32
+    try expectEqualHexStrings("\xFF\x24\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [ds:0x10]");
+}