Commit b9a6f81d1a

Jakub Konka <kubkon@jakubkonka.com>
2021-12-22 18:53:20
stage2: add lowering fn for OI encoding
Implement movabs using OI generic encoding.
1 parent 2b5de94
Changed files (1)
src
arch
x86_64
src/arch/x86_64/Emit.zig
@@ -541,61 +541,64 @@ const Encoding = enum {
 
     /// OP r64, r/m64
     rm,
-};
 
-const OpCode = struct {
-    opc: u8,
-    /// Only used if `Encoding == .mi`.
-    modrm_ext: u3,
+    /// OP r64, imm64
+    oi,
 };
 
-inline fn getOpCode(tag: Mir.Inst.Tag, enc: Encoding) OpCode {
+inline fn getOpCode(tag: Mir.Inst.Tag, enc: Encoding) u8 {
     switch (enc) {
         .mi => return switch (tag) {
-            .adc => .{ .opc = 0x81, .modrm_ext = 0x2 },
-            .add => .{ .opc = 0x81, .modrm_ext = 0x0 },
-            .sub => .{ .opc = 0x81, .modrm_ext = 0x5 },
-            .xor => .{ .opc = 0x81, .modrm_ext = 0x6 },
-            .@"and" => .{ .opc = 0x81, .modrm_ext = 0x4 },
-            .@"or" => .{ .opc = 0x81, .modrm_ext = 0x1 },
-            .sbb => .{ .opc = 0x81, .modrm_ext = 0x3 },
-            .cmp => .{ .opc = 0x81, .modrm_ext = 0x7 },
-            .mov => .{ .opc = 0xc7, .modrm_ext = 0x0 },
+            .adc, .add, .sub, .xor, .@"and", .@"or", .sbb, .cmp => 0x81,
+            .mov => 0xc7,
             else => unreachable,
         },
-        .mr => {
-            const opc: u8 = switch (tag) {
-                .adc => 0x11,
-                .add => 0x01,
-                .sub => 0x29,
-                .xor => 0x31,
-                .@"and" => 0x21,
-                .@"or" => 0x09,
-                .sbb => 0x19,
-                .cmp => 0x39,
-                .mov => 0x89,
-                else => unreachable,
-            };
-            return .{ .opc = opc, .modrm_ext = undefined };
+        .mr => return switch (tag) {
+            .adc => 0x11,
+            .add => 0x01,
+            .sub => 0x29,
+            .xor => 0x31,
+            .@"and" => 0x21,
+            .@"or" => 0x09,
+            .sbb => 0x19,
+            .cmp => 0x39,
+            .mov => 0x89,
+            else => unreachable,
         },
-        .rm => {
-            const opc: u8 = switch (tag) {
-                .adc => 0x13,
-                .add => 0x03,
-                .sub => 0x2b,
-                .xor => 0x33,
-                .@"and" => 0x23,
-                .@"or" => 0x0b,
-                .sbb => 0x1b,
-                .cmp => 0x3b,
-                .mov => 0x8b,
-                else => unreachable,
-            };
-            return .{ .opc = opc, .modrm_ext = undefined };
+        .rm => return switch (tag) {
+            .adc => 0x13,
+            .add => 0x03,
+            .sub => 0x2b,
+            .xor => 0x33,
+            .@"and" => 0x23,
+            .@"or" => 0x0b,
+            .sbb => 0x1b,
+            .cmp => 0x3b,
+            .mov => 0x8b,
+            else => unreachable,
+        },
+        .oi => return switch (tag) {
+            .mov => 0xb8,
+            else => unreachable,
         },
     }
 }
 
+inline fn getMiModRmExt(tag: Mir.Inst.Tag) u3 {
+    return switch (tag) {
+        .adc => 0x2,
+        .add => 0x0,
+        .sub => 0x5,
+        .xor => 0x6,
+        .@"and" => 0x4,
+        .@"or" => 0x1,
+        .sbb => 0x3,
+        .cmp => 0x7,
+        .mov => 0x0,
+        else => unreachable,
+    };
+}
+
 const ScaleIndexBase = struct {
     scale: u2,
     index_reg: ?Register,
@@ -626,16 +629,56 @@ const RegisterOrMemory = union(enum) {
     }
 };
 
+fn lowerToOiEnc(
+    tag: Mir.Inst.Tag,
+    reg: Register,
+    imm: i64,
+    code: *std.ArrayList(u8),
+) InnerError!void {
+    var opc = getOpCode(tag, .oi);
+    if (reg.size() != immOpSize(imm)) return error.EmitFail;
+    if (reg.size() == 8) {
+        opc -= 8;
+    }
+    const encoder = try Encoder.init(code, 10);
+    encoder.rex(.{
+        .w = reg.size() == 64,
+        .b = reg.isExtended(),
+    });
+    encoder.opcode_withReg(opc, reg.lowId());
+    switch (reg.size()) {
+        8 => {
+            const imm8 = try math.cast(i8, imm);
+            encoder.imm8(imm8);
+        },
+        16 => {
+            const imm16 = try math.cast(i16, imm);
+            encoder.imm16(imm16);
+        },
+        32 => {
+            const imm32 = try math.cast(i32, imm);
+            encoder.imm32(imm32);
+        },
+        64 => {
+            encoder.imm64(@bitCast(u64, imm));
+        },
+        else => unreachable,
+    }
+}
+
 fn lowerToMiEnc(
     tag: Mir.Inst.Tag,
     reg_or_mem: RegisterOrMemory,
     imm: i32,
     code: *std.ArrayList(u8),
 ) InnerError!void {
-    const opcode = getOpCode(tag, .mi);
+    var opc = getOpCode(tag, .mi);
+    const modrm_ext = getMiModRmExt(tag);
     switch (reg_or_mem) {
         .register => |dst_reg| {
-            const opc: u8 = if (dst_reg.size() == 8) opcode.opc - 1 else opcode.opc;
+            if (dst_reg.size() == 8) {
+                opc -= 1;
+            }
             const encoder = try Encoder.init(code, 7);
             if (dst_reg.size() == 16) {
                 // 0x66 prefix switches to the non-default size; here we assume a switch from
@@ -648,7 +691,7 @@ fn lowerToMiEnc(
                 .b = dst_reg.isExtended(),
             });
             encoder.opcode_1byte(opc);
-            encoder.modRm_direct(opcode.modrm_ext, dst_reg.lowId());
+            encoder.modRm_direct(modrm_ext, dst_reg.lowId());
             switch (dst_reg.size()) {
                 8 => {
                     const imm8 = try math.cast(i8, imm);
@@ -676,25 +719,25 @@ fn lowerToMiEnc(
                     .w = false,
                     .b = dst_reg.isExtended(),
                 });
-                encoder.opcode_1byte(opcode.opc);
+                encoder.opcode_1byte(opc);
                 if (dst_mem.disp == 0) {
-                    encoder.modRm_indirectDisp0(opcode.modrm_ext, dst_reg.lowId());
+                    encoder.modRm_indirectDisp0(modrm_ext, dst_reg.lowId());
                 } else if (immOpSize(dst_mem.disp) == 8) {
-                    encoder.modRm_indirectDisp8(opcode.modrm_ext, dst_reg.lowId());
+                    encoder.modRm_indirectDisp8(modrm_ext, dst_reg.lowId());
                     encoder.disp8(@intCast(i8, dst_mem.disp));
                 } else {
                     if (dst_reg.lowId() == 4) {
-                        encoder.modRm_SIBDisp32(opcode.modrm_ext);
+                        encoder.modRm_SIBDisp32(modrm_ext);
                         encoder.sib_baseDisp32(dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
                     } else {
-                        encoder.modRm_indirectDisp32(opcode.modrm_ext, dst_reg.lowId());
+                        encoder.modRm_indirectDisp32(modrm_ext, dst_reg.lowId());
                         encoder.disp32(dst_mem.disp);
                     }
                 }
             } else {
-                encoder.opcode_1byte(opcode.opc);
-                encoder.modRm_SIBDisp0(opcode.modrm_ext);
+                encoder.opcode_1byte(opc);
+                encoder.modRm_SIBDisp0(modrm_ext);
                 encoder.sib_disp32();
                 encoder.disp32(dst_mem.disp);
             }
@@ -709,8 +752,10 @@ fn lowerToRmEnc(
     reg_or_mem: RegisterOrMemory,
     code: *std.ArrayList(u8),
 ) InnerError!void {
-    const opcode = getOpCode(tag, .rm);
-    const opc: u8 = if (reg.size() == 8) opcode.opc - 1 else opcode.opc;
+    var opc = getOpCode(tag, .rm);
+    if (reg.size() == 8) {
+        opc -= 1;
+    }
     switch (reg_or_mem) {
         .register => |src_reg| {
             if (reg.size() != src_reg.size()) return error.EmitFail;
@@ -779,8 +824,10 @@ fn lowerToMrEnc(
     // * reg is 32bit - dword ptr
     // * reg is 16bit - word ptr
     // * reg is 8bit - byte ptr
-    const opcode = getOpCode(tag, .mr);
-    const opc: u8 = if (reg.size() == 8) opcode.opc - 1 else opcode.opc;
+    var opc = getOpCode(tag, .mr);
+    if (reg.size() == 8) {
+        opc -= 1;
+    }
     switch (reg_or_mem) {
         .register => |dst_reg| {
             if (dst_reg.size() != reg.size()) return error.EmitFail;
@@ -890,7 +937,7 @@ fn mirArith(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!voi
     }
 }
 
-fn immOpSize(imm: i32) u8 {
+fn immOpSize(imm: i64) u8 {
     blk: {
         _ = math.cast(i8, imm) catch break :blk;
         return 8;
@@ -899,15 +946,21 @@ fn immOpSize(imm: i32) u8 {
         _ = math.cast(i16, imm) catch break :blk;
         return 16;
     }
-    return 32;
+    blk: {
+        _ = math.cast(i32, imm) catch break :blk;
+        return 32;
+    }
+    return 64;
 }
 
 fn mirArithScaleSrc(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const scale = ops.flags;
     // OP reg1, [reg2 + scale*rcx + imm32]
-    const opcode = getOpCode(tag, .rm);
-    const opc = if (ops.reg1.size() == 8) opcode.opc - 1 else opcode.opc;
+    var opc = getOpCode(tag, .rm);
+    if (ops.reg1.size() == 8) {
+        opc -= 1;
+    }
     const imm = emit.mir.instructions.items(.data)[inst].imm;
     const encoder = try Encoder.init(emit.code, 8);
     encoder.rex(.{
@@ -934,15 +987,18 @@ fn mirArithScaleDst(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerE
 
     if (ops.reg2 == .none) {
         // OP [reg1 + scale*rax + 0], imm32
-        const opcode = getOpCode(tag, .mi);
-        const opc = if (ops.reg1.size() == 8) opcode.opc - 1 else opcode.opc;
+        var opc = getOpCode(tag, .mi);
+        const modrm_ext = getMiModRmExt(tag);
+        if (ops.reg1.size() == 8) {
+            opc -= 1;
+        }
         const encoder = try Encoder.init(emit.code, 8);
         encoder.rex(.{
             .w = ops.reg1.size() == 64,
             .b = ops.reg1.isExtended(),
         });
         encoder.opcode_1byte(opc);
-        encoder.modRm_SIBDisp0(opcode.modrm_ext);
+        encoder.modRm_SIBDisp0(modrm_ext);
         encoder.sib_scaleIndexBase(scale, Register.rax.lowId(), ops.reg1.lowId());
         if (imm <= math.maxInt(i8)) {
             encoder.imm8(@intCast(i8, imm));
@@ -955,8 +1011,10 @@ fn mirArithScaleDst(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerE
     }
 
     // OP [reg1 + scale*rax + imm32], reg2
-    const opcode = getOpCode(tag, .mr);
-    const opc = if (ops.reg1.size() == 8) opcode.opc - 1 else opcode.opc;
+    var opc = getOpCode(tag, .mr);
+    if (ops.reg1.size() == 8) {
+        opc -= 1;
+    }
     const encoder = try Encoder.init(emit.code, 8);
     encoder.rex(.{
         .w = ops.reg1.size() == 64,
@@ -980,8 +1038,11 @@ fn mirArithScaleImm(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerE
     const scale = ops.flags;
     const payload = emit.mir.instructions.items(.data)[inst].payload;
     const imm_pair = emit.mir.extraData(Mir.ImmPair, payload).data;
-    const opcode = getOpCode(tag, .mi);
-    const opc = if (ops.reg1.size() == 8) opcode.opc - 1 else opcode.opc;
+    var opc = getOpCode(tag, .mi);
+    if (ops.reg1.size() == 8) {
+        opc -= 1;
+    }
+    const modrm_ext = getMiModRmExt(tag);
     const encoder = try Encoder.init(emit.code, 2);
     encoder.rex(.{
         .w = ops.reg1.size() == 64,
@@ -989,11 +1050,11 @@ fn mirArithScaleImm(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerE
     });
     encoder.opcode_1byte(opc);
     if (imm_pair.dest_off <= math.maxInt(i8)) {
-        encoder.modRm_SIBDisp8(opcode.modrm_ext);
+        encoder.modRm_SIBDisp8(modrm_ext);
         encoder.sib_scaleIndexBaseDisp8(scale, Register.rax.lowId(), ops.reg1.lowId());
         encoder.disp8(@intCast(i8, imm_pair.dest_off));
     } else {
-        encoder.modRm_SIBDisp32(opcode.modrm_ext);
+        encoder.modRm_SIBDisp32(modrm_ext);
         encoder.sib_scaleIndexBaseDisp32(scale, Register.rax.lowId(), ops.reg1.lowId());
         encoder.disp32(imm_pair.dest_off);
     }
@@ -1005,21 +1066,19 @@ fn mirMovabs(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
     assert(tag == .movabs);
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
 
+    if (ops.flags == 0b00) {
+        // movabs reg, imm64
+        // OI
+        const imm: i64 = if (ops.reg1.size() == 64) blk: {
+            const payload = emit.mir.instructions.items(.data)[inst].payload;
+            const imm = emit.mir.extraData(Mir.Imm64, payload).data;
+            break :blk @bitCast(i64, imm.decode());
+        } else emit.mir.instructions.items(.data)[inst].imm;
+        return lowerToOiEnc(.mov, ops.reg1, imm, emit.code);
+    }
+
     const encoder = try Encoder.init(emit.code, 10);
     const is_64 = blk: {
-        if (ops.flags == 0b00) {
-            // movabs reg, imm64
-            const opc: u8 = if (ops.reg1.size() == 8) 0xb0 else 0xb8;
-            if (ops.reg1.size() == 64) {
-                encoder.rex(.{
-                    .w = true,
-                    .b = ops.reg1.isExtended(),
-                });
-                encoder.opcode_withReg(opc, ops.reg1.lowId());
-                break :blk true;
-            }
-            break :blk false;
-        }
         if (ops.reg1 == .none) {
             // movabs moffs64, rax
             const opc: u8 = if (ops.reg2.size() == 8) 0xa2 else 0xa3;
@@ -1486,3 +1545,24 @@ test "lower MR encoding" {
         "sub qword ptr [r11 + 0x10000000], r12",
     );
 }
+
+test "lower OI encoding" {
+    var code = TestEmitCode.init();
+    defer code.deinit();
+    try lowerToOiEnc(.mov, .rax, 0x1000000000000000, code.buffer());
+    try expectEqualHexStrings(
+        "\x48\xB8\x00\x00\x00\x00\x00\x00\x00\x10",
+        code.emitted(),
+        "movabs rax, 0x1000000000000000",
+    );
+    try lowerToOiEnc(.mov, .r11, 0x1000000000000000, code.buffer());
+    try expectEqualHexStrings(
+        "\x49\xBB\x00\x00\x00\x00\x00\x00\x00\x10",
+        code.emitted(),
+        "movabs r11, 0x1000000000000000",
+    );
+    try lowerToOiEnc(.mov, .r11d, 0x10000000, code.buffer());
+    try expectEqualHexStrings("\x41\xBB\x00\x00\x00\x10", code.emitted(), "mov r11d, 0x10000000");
+    try lowerToOiEnc(.mov, .r11b, 0x10, code.buffer());
+    try expectEqualHexStrings("\x41\xB3\x10", code.emitted(), "mov r11b, 0x10");
+}