Commit 8c664d3f6a

Jakub Konka <kubkon@jakubkonka.com>
2021-12-23 18:49:03
stage2: support multibyte opcodes and refactor 1byte opcode changes
1 parent d23a148
Changed files (1)
src
arch
x86_64
src/arch/x86_64/Emit.zig
@@ -193,8 +193,7 @@ fn mirNop(emit: *Emit) InnerError!void {
 }
 
 fn mirSyscall(emit: *Emit) InnerError!void {
-    const encoder = try Encoder.init(emit.code, 2);
-    encoder.opcode_2byte(0x0f, 0x05);
+    return lowerToZoEnc(.syscall, emit.code);
 }
 
 fn mirPushPop(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
@@ -470,6 +469,7 @@ const Tag = enum {
     @"test",
     brk,
     nop,
+    syscall,
     ret_near,
     ret_far,
 };
@@ -509,78 +509,104 @@ const Encoding = enum {
     td,
 };
 
-inline fn getOpCode(tag: Tag, enc: Encoding) ?u8 {
+const OpCode = union(enum) {
+    one_byte: u8,
+    two_byte: struct { _1: u8, _2: u8 },
+
+    fn oneByte(opc: u8) OpCode {
+        return .{ .one_byte = opc };
+    }
+
+    fn twoByte(opc1: u8, opc2: u8) OpCode {
+        return .{ .two_byte = .{ ._1 = opc1, ._2 = opc2 } };
+    }
+
+    fn encode(opc: OpCode, encoder: Encoder) void {
+        switch (opc) {
+            .one_byte => |v| encoder.opcode_1byte(v),
+            .two_byte => |v| encoder.opcode_2byte(v._1, v._2),
+        }
+    }
+
+    fn encodeWithReg(opc: OpCode, encoder: Encoder, reg: Register) void {
+        assert(opc == .one_byte);
+        encoder.opcode_withReg(opc.one_byte, reg.lowId());
+    }
+};
+
+inline fn getOpCode(tag: Tag, enc: Encoding, is_one_byte: bool) ?OpCode {
     switch (enc) {
         .zo => return switch (tag) {
-            .ret_near => 0xc3,
-            .ret_far => 0xcb,
-            .brk => 0xcc,
-            .nop => 0x90,
+            .ret_near => OpCode.oneByte(0xc3),
+            .ret_far => OpCode.oneByte(0xcb),
+            .brk => OpCode.oneByte(0xcc),
+            .nop => OpCode.oneByte(0x90),
+            .syscall => OpCode.twoByte(0x0f, 0x05),
             else => null,
         },
         .d => return switch (tag) {
-            .jmp_near => 0xe9,
-            .call_near => 0xe8,
+            .jmp_near => OpCode.oneByte(0xe9),
+            .call_near => OpCode.oneByte(0xe8),
             else => null,
         },
         .m => return switch (tag) {
-            .jmp_near, .call_near, .push => 0xff,
-            .pop => 0x8f,
+            .jmp_near, .call_near, .push => OpCode.oneByte(0xff),
+            .pop => OpCode.oneByte(0x8f),
             else => null,
         },
         .o => return switch (tag) {
-            .push => 0x50,
-            .pop => 0x58,
+            .push => OpCode.oneByte(0x50),
+            .pop => OpCode.oneByte(0x58),
             else => null,
         },
         .i => return switch (tag) {
-            .push => 0x68,
-            .@"test" => 0xa9,
-            .ret_near => 0xc2,
-            .ret_far => 0xca,
+            .push => OpCode.oneByte(if (is_one_byte) 0x6a else 0x68),
+            .@"test" => OpCode.oneByte(if (is_one_byte) 0xa8 else 0xa9),
+            .ret_near => OpCode.oneByte(0xc2),
+            .ret_far => OpCode.oneByte(0xca),
             else => null,
         },
         .mi => return switch (tag) {
-            .adc, .add, .sub, .xor, .@"and", .@"or", .sbb, .cmp => 0x81,
-            .mov => 0xc7,
-            .@"test" => 0xf7,
+            .adc, .add, .sub, .xor, .@"and", .@"or", .sbb, .cmp => OpCode.oneByte(if (is_one_byte) 0x80 else 0x81),
+            .mov => OpCode.oneByte(if (is_one_byte) 0xc6 else 0xc7),
+            .@"test" => OpCode.oneByte(if (is_one_byte) 0xf6 else 0xf7),
             else => null,
         },
         .mr => return switch (tag) {
-            .adc => 0x11,
-            .add => 0x01,
-            .sub => 0x29,
-            .xor => 0x31,
-            .@"and" => 0x21,
-            .@"or" => 0x09,
-            .sbb => 0x19,
-            .cmp => 0x39,
-            .mov => 0x89,
+            .adc => OpCode.oneByte(if (is_one_byte) 0x10 else 0x11),
+            .add => OpCode.oneByte(if (is_one_byte) 0x00 else 0x01),
+            .sub => OpCode.oneByte(if (is_one_byte) 0x28 else 0x29),
+            .xor => OpCode.oneByte(if (is_one_byte) 0x30 else 0x31),
+            .@"and" => OpCode.oneByte(if (is_one_byte) 0x20 else 0x21),
+            .@"or" => OpCode.oneByte(if (is_one_byte) 0x08 else 0x09),
+            .sbb => OpCode.oneByte(if (is_one_byte) 0x18 else 0x19),
+            .cmp => OpCode.oneByte(if (is_one_byte) 0x38 else 0x39),
+            .mov => OpCode.oneByte(if (is_one_byte) 0x88 else 0x89),
             else => null,
         },
         .rm => return switch (tag) {
-            .adc => 0x13,
-            .add => 0x03,
-            .sub => 0x2b,
-            .xor => 0x33,
-            .@"and" => 0x23,
-            .@"or" => 0x0b,
-            .sbb => 0x1b,
-            .cmp => 0x3b,
-            .mov => 0x8b,
-            .lea => 0x8d,
+            .adc => OpCode.oneByte(if (is_one_byte) 0x12 else 0x13),
+            .add => OpCode.oneByte(if (is_one_byte) 0x02 else 0x03),
+            .sub => OpCode.oneByte(if (is_one_byte) 0x2a else 0x2b),
+            .xor => OpCode.oneByte(if (is_one_byte) 0x32 else 0x33),
+            .@"and" => OpCode.oneByte(if (is_one_byte) 0x22 else 0x23),
+            .@"or" => OpCode.oneByte(if (is_one_byte) 0x0b else 0x0b),
+            .sbb => OpCode.oneByte(if (is_one_byte) 0x1a else 0x1b),
+            .cmp => OpCode.oneByte(if (is_one_byte) 0x3a else 0x3b),
+            .mov => OpCode.oneByte(if (is_one_byte) 0x8a else 0x8b),
+            .lea => OpCode.oneByte(if (is_one_byte) 0x8c else 0x8d),
             else => null,
         },
         .oi => return switch (tag) {
-            .mov => 0xb8,
+            .mov => OpCode.oneByte(if (is_one_byte) 0xb0 else 0xb8),
             else => null,
         },
         .fd => return switch (tag) {
-            .mov => 0xa1,
+            .mov => OpCode.oneByte(if (is_one_byte) 0xa0 else 0xa1),
             else => null,
         },
         .td => return switch (tag) {
-            .mov => 0xa3,
+            .mov => OpCode.oneByte(if (is_one_byte) 0xa2 else 0xa3),
             else => null,
         },
     }
@@ -648,32 +674,25 @@ const RegisterOrMemory = union(enum) {
 };
 
 fn lowerToZoEnc(tag: Tag, code: *std.ArrayList(u8)) InnerError!void {
-    const opc = getOpCode(tag, .zo).?;
+    const opc = getOpCode(tag, .zo, false).?;
     const encoder = try Encoder.init(code, 1);
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
 }
 
 fn lowerToIEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
-    var opc = getOpCode(tag, .i).?;
     if (tag == .ret_far or tag == .ret_near) {
         const encoder = try Encoder.init(code, 3);
-        encoder.opcode_1byte(opc);
+        const opc = getOpCode(tag, .i, false).?;
+        opc.encode(encoder);
         encoder.imm16(@intCast(i16, imm));
         return;
     }
-    if (immOpSize(imm) == 8) {
-        // TODO I think getOpCode should track this
-        switch (tag) {
-            .push => opc += 2,
-            .@"test" => opc -= 1,
-            else => return error.EmitFail,
-        }
-    }
+    const opc = getOpCode(tag, .i, immOpSize(imm) == 8).?;
     const encoder = try Encoder.init(code, 5);
     if (immOpSize(imm) == 16) {
         encoder.opcode_1byte(0x66);
     }
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     if (immOpSize(imm) == 8) {
         encoder.imm8(@intCast(i8, imm));
     } else if (immOpSize(imm) == 16) {
@@ -685,7 +704,7 @@ fn lowerToIEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
 
 fn lowerToOEnc(tag: Tag, reg: Register, code: *std.ArrayList(u8)) InnerError!void {
     if (reg.size() != 16 and reg.size() != 64) return error.EmitFail; // TODO correct for push/pop, but is it universal?
-    const opc = getOpCode(tag, .o).?;
+    const opc = getOpCode(tag, .o, false).?;
     const encoder = try Encoder.init(code, 3);
     if (reg.size() == 16) {
         encoder.opcode_1byte(0x66);
@@ -694,18 +713,18 @@ fn lowerToOEnc(tag: Tag, reg: Register, code: *std.ArrayList(u8)) InnerError!voi
         .w = false,
         .b = reg.isExtended(),
     });
-    encoder.opcode_withReg(opc, reg.lowId());
+    opc.encodeWithReg(encoder, reg);
 }
 
 fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
-    const opc = getOpCode(tag, .d).?;
+    const opc = getOpCode(tag, .d, false).?;
     const encoder = try Encoder.init(code, 5);
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     encoder.imm32(imm);
 }
 
 fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8)) InnerError!void {
-    const opc = getOpCode(tag, .m).?;
+    const opc = getOpCode(tag, .m, false).?;
     const modrm_ext = getModRmExt(tag).?;
     switch (reg_or_mem) {
         .register => |reg| {
@@ -715,7 +734,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
                 .w = false,
                 .b = reg.isExtended(),
             });
-            encoder.opcode_1byte(opc);
+            opc.encode(encoder);
             encoder.modRm_direct(modrm_ext, reg.lowId());
         },
         .memory => |mem_op| {
@@ -726,7 +745,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
                     .w = false,
                     .b = reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (reg.lowId() == 4) {
                     if (mem_op.disp == 0) {
                         encoder.modRm_SIBDisp0(modrm_ext);
@@ -752,7 +771,7 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
                     }
                 }
             } else {
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (mem_op.rip) {
                     encoder.modRm_RIPDisp32(modrm_ext);
                 } else {
@@ -776,10 +795,10 @@ fn lowerToFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8)) I
 fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8), td: bool) InnerError!void {
     if (reg.lowId() != Register.rax.lowId()) return error.EmitFail;
     if (reg.size() != immOpSize(moffs)) return error.EmitFail;
-    var opc = if (td) getOpCode(tag, .td).? else getOpCode(tag, .fd).?;
-    if (reg.size() == 8) {
-        opc -= 1;
-    }
+    const opc = if (td)
+        getOpCode(tag, .td, reg.size() == 8).?
+    else
+        getOpCode(tag, .fd, reg.size() == 8).?;
     const encoder = try Encoder.init(code, 10);
     if (reg.size() == 16) {
         encoder.opcode_1byte(0x66);
@@ -787,7 +806,7 @@ fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8),
     encoder.rex(.{
         .w = reg.size() == 64,
     });
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     switch (reg.size()) {
         8 => {
             const moffs8 = try math.cast(i8, moffs);
@@ -809,11 +828,8 @@ fn lowerToTdFdEnc(tag: Tag, reg: Register, moffs: i64, code: *std.ArrayList(u8),
 }
 
 fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) InnerError!void {
-    var opc = getOpCode(tag, .oi).?;
     if (reg.size() != immOpSize(imm)) return error.EmitFail;
-    if (reg.size() == 8) {
-        opc -= 8;
-    }
+    const opc = getOpCode(tag, .oi, reg.size() == 8).?;
     const encoder = try Encoder.init(code, 10);
     if (reg.size() == 16) {
         encoder.opcode_1byte(0x66);
@@ -822,7 +838,7 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn
         .w = reg.size() == 64,
         .b = reg.isExtended(),
     });
-    encoder.opcode_withReg(opc, reg.lowId());
+    opc.encodeWithReg(encoder, reg);
     switch (reg.size()) {
         8 => {
             const imm8 = try math.cast(i8, imm);
@@ -844,13 +860,10 @@ fn lowerToOiEnc(tag: Tag, reg: Register, imm: i64, code: *std.ArrayList(u8)) Inn
 }
 
 fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
-    var opc = getOpCode(tag, .mi).?;
     const modrm_ext = getModRmExt(tag).?;
     switch (reg_or_mem) {
         .register => |dst_reg| {
-            if (dst_reg.size() == 8) {
-                opc -= 1;
-            }
+            const opc = getOpCode(tag, .mi, dst_reg.size() == 8).?;
             const encoder = try Encoder.init(code, 7);
             if (dst_reg.size() == 16) {
                 // 0x66 prefix switches to the non-default size; here we assume a switch from
@@ -862,7 +875,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
                 .w = dst_reg.size() == 64,
                 .b = dst_reg.isExtended(),
             });
-            encoder.opcode_1byte(opc);
+            opc.encode(encoder);
             encoder.modRm_direct(modrm_ext, dst_reg.lowId());
             switch (dst_reg.size()) {
                 8 => {
@@ -878,6 +891,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
             }
         },
         .memory => |dst_mem| {
+            const opc = getOpCode(tag, .mi, false).?;
             const encoder = try Encoder.init(code, 12);
             if (dst_mem.reg) |dst_reg| {
                 // Register dst_reg can either be 64bit or 32bit in size.
@@ -891,7 +905,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
                     .w = false,
                     .b = dst_reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (dst_reg.lowId() == 4) {
                     if (dst_mem.disp == 0) {
                         encoder.modRm_SIBDisp0(modrm_ext);
@@ -917,7 +931,7 @@ fn lowerToMiEnc(tag: Tag, reg_or_mem: RegisterOrMemory, imm: i32, code: *std.Arr
                     }
                 }
             } else {
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (dst_mem.rip) {
                     encoder.modRm_RIPDisp32(modrm_ext);
                 } else {
@@ -937,10 +951,7 @@ fn lowerToRmEnc(
     reg_or_mem: RegisterOrMemory,
     code: *std.ArrayList(u8),
 ) InnerError!void {
-    var opc = getOpCode(tag, .rm).?;
-    if (reg.size() == 8) {
-        opc -= 1;
-    }
+    const opc = getOpCode(tag, .rm, reg.size() == 8).?;
     switch (reg_or_mem) {
         .register => |src_reg| {
             if (reg.size() != src_reg.size()) return error.EmitFail;
@@ -950,7 +961,7 @@ fn lowerToRmEnc(
                 .r = reg.isExtended(),
                 .b = src_reg.isExtended(),
             });
-            encoder.opcode_1byte(opc);
+            opc.encode(encoder);
             encoder.modRm_direct(reg.lowId(), src_reg.lowId());
         },
         .memory => |src_mem| {
@@ -967,7 +978,7 @@ fn lowerToRmEnc(
                     .r = reg.isExtended(),
                     .b = src_reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (src_reg.lowId() == 4) {
                     if (src_mem.disp == 0) {
                         encoder.modRm_SIBDisp0(reg.lowId());
@@ -997,7 +1008,7 @@ fn lowerToRmEnc(
                     .w = reg.size() == 64,
                     .r = reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (src_mem.rip) {
                     encoder.modRm_RIPDisp32(reg.lowId());
                 } else {
@@ -1022,10 +1033,7 @@ fn lowerToMrEnc(
     // * reg is 32bit - dword ptr
     // * reg is 16bit - word ptr
     // * reg is 8bit - byte ptr
-    var opc = getOpCode(tag, .mr).?;
-    if (reg.size() == 8) {
-        opc -= 1;
-    }
+    const opc = getOpCode(tag, .mr, reg.size() == 8).?;
     switch (reg_or_mem) {
         .register => |dst_reg| {
             if (dst_reg.size() != reg.size()) return error.EmitFail;
@@ -1035,7 +1043,7 @@ fn lowerToMrEnc(
                 .r = reg.isExtended(),
                 .b = dst_reg.isExtended(),
             });
-            encoder.opcode_1byte(opc);
+            opc.encode(encoder);
             encoder.modRm_direct(reg.lowId(), dst_reg.lowId());
         },
         .memory => |dst_mem| {
@@ -1050,7 +1058,7 @@ fn lowerToMrEnc(
                     .r = reg.isExtended(),
                     .b = dst_reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (dst_reg.lowId() == 4) {
                     if (dst_mem.disp == 0) {
                         encoder.modRm_SIBDisp0(reg.lowId());
@@ -1080,7 +1088,7 @@ fn lowerToMrEnc(
                     .w = reg.size() == 64,
                     .r = reg.isExtended(),
                 });
-                encoder.opcode_1byte(opc);
+                opc.encode(encoder);
                 if (dst_mem.rip) {
                     encoder.modRm_RIPDisp32(reg.lowId());
                 } else {
@@ -1168,10 +1176,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const scale = ops.flags;
     // OP reg1, [reg2 + scale*rcx + imm32]
-    var opc = getOpCode(tag, .rm).?;
-    if (ops.reg1.size() == 8) {
-        opc -= 1;
-    }
+    const opc = getOpCode(tag, .rm, ops.reg1.size() == 8).?;
     const imm = emit.mir.instructions.items(.data)[inst].imm;
     const encoder = try Encoder.init(emit.code, 8);
     encoder.rex(.{
@@ -1179,7 +1184,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
         .r = ops.reg1.isExtended(),
         .b = ops.reg2.isExtended(),
     });
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     if (imm <= math.maxInt(i8)) {
         encoder.modRm_SIBDisp8(ops.reg1.lowId());
         encoder.sib_scaleIndexBaseDisp8(scale, Register.rcx.lowId(), ops.reg2.lowId());
@@ -1198,17 +1203,14 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
 
     if (ops.reg2 == .none) {
         // OP [reg1 + scale*rax + 0], imm32
-        var opc = getOpCode(tag, .mi).?;
+        const opc = getOpCode(tag, .mi, ops.reg1.size() == 8).?;
         const modrm_ext = getModRmExt(tag).?;
-        if (ops.reg1.size() == 8) {
-            opc -= 1;
-        }
         const encoder = try Encoder.init(emit.code, 8);
         encoder.rex(.{
             .w = ops.reg1.size() == 64,
             .b = ops.reg1.isExtended(),
         });
-        encoder.opcode_1byte(opc);
+        opc.encode(encoder);
         encoder.modRm_SIBDisp0(modrm_ext);
         encoder.sib_scaleIndexBase(scale, Register.rax.lowId(), ops.reg1.lowId());
         if (imm <= math.maxInt(i8)) {
@@ -1222,17 +1224,14 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     }
 
     // OP [reg1 + scale*rax + imm32], reg2
-    var opc = getOpCode(tag, .mr).?;
-    if (ops.reg1.size() == 8) {
-        opc -= 1;
-    }
+    const opc = getOpCode(tag, .mr, ops.reg1.size() == 8).?;
     const encoder = try Encoder.init(emit.code, 8);
     encoder.rex(.{
         .w = ops.reg1.size() == 64,
         .r = ops.reg2.isExtended(),
         .b = ops.reg1.isExtended(),
     });
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     if (imm <= math.maxInt(i8)) {
         encoder.modRm_SIBDisp8(ops.reg2.lowId());
         encoder.sib_scaleIndexBaseDisp8(scale, Register.rax.lowId(), ops.reg1.lowId());
@@ -1249,17 +1248,14 @@ fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     const scale = ops.flags;
     const payload = emit.mir.instructions.items(.data)[inst].payload;
     const imm_pair = emit.mir.extraData(Mir.ImmPair, payload).data;
-    var opc = getOpCode(tag, .mi).?;
-    if (ops.reg1.size() == 8) {
-        opc -= 1;
-    }
+    const opc = getOpCode(tag, .mi, ops.reg1.size() == 8).?;
     const modrm_ext = getModRmExt(tag).?;
     const encoder = try Encoder.init(emit.code, 2);
     encoder.rex(.{
         .w = ops.reg1.size() == 64,
         .b = ops.reg1.isExtended(),
     });
-    encoder.opcode_1byte(opc);
+    opc.encode(encoder);
     if (imm_pair.dest_off <= math.maxInt(i8)) {
         encoder.modRm_SIBDisp8(modrm_ext);
         encoder.sib_scaleIndexBaseDisp8(scale, Register.rax.lowId(), ops.reg1.lowId());