Commit c50bb2b80f

Jakub Konka <kubkon@jakubkonka.com>
2021-12-23 20:29:34
stage2: lower jcc and setcc conditional jump/set instructions
1 parent 8c664d3
Changed files (2)
src
src/arch/x86_64/CodeGen.zig
@@ -2898,7 +2898,7 @@ fn genSetReg(self: *Self, ty: Type, reg: Register, mcv: MCValue) InnerError!void
             _ = try self.addInst(.{
                 .tag = tag,
                 .ops = (Mir.Ops{
-                    .reg1 = reg,
+                    .reg1 = reg.to8(),
                     .flags = flags,
                 }).encode(),
                 .data = undefined,
src/arch/x86_64/Emit.zig
@@ -260,149 +260,60 @@ fn mirJmpCall(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
 }
 
-const CondType = enum {
-    /// greater than or equal
-    gte,
-
-    /// greater than
-    gt,
-
-    /// less than
-    lt,
-
-    /// less than or equal
-    lte,
-
-    /// above or equal
-    ae,
-
-    /// above
-    a,
-
-    /// below
-    b,
-
-    /// below or equal
-    be,
-
-    /// not equal
-    ne,
-
-    /// equal
-    eq,
-
-    fn fromTagAndFlags(tag: Mir.Inst.Tag, flags: u2) CondType {
-        return switch (tag) {
-            .cond_jmp_greater_less,
-            .cond_set_byte_greater_less,
-            => switch (flags) {
-                0b00 => CondType.gte,
-                0b01 => CondType.gt,
-                0b10 => CondType.lt,
-                0b11 => CondType.lte,
-            },
-            .cond_jmp_above_below,
-            .cond_set_byte_above_below,
-            => switch (flags) {
-                0b00 => CondType.ae,
-                0b01 => CondType.a,
-                0b10 => CondType.b,
-                0b11 => CondType.be,
-            },
-            .cond_jmp_eq_ne,
-            .cond_set_byte_eq_ne,
-            => switch (@truncate(u1, flags)) {
-                0b0 => CondType.ne,
-                0b1 => CondType.eq,
-            },
-            else => unreachable,
-        };
-    }
-};
-
-inline fn getCondOpCode(tag: Mir.Inst.Tag, cond: CondType) u8 {
-    switch (cond) {
-        .gte => return switch (tag) {
-            .cond_jmp_greater_less => 0x8d,
-            .cond_set_byte_greater_less => 0x9d,
-            else => unreachable,
-        },
-        .gt => return switch (tag) {
-            .cond_jmp_greater_less => 0x8f,
-            .cond_set_byte_greater_less => 0x9f,
-            else => unreachable,
-        },
-        .lt => return switch (tag) {
-            .cond_jmp_greater_less => 0x8c,
-            .cond_set_byte_greater_less => 0x9c,
-            else => unreachable,
-        },
-        .lte => return switch (tag) {
-            .cond_jmp_greater_less => 0x8e,
-            .cond_set_byte_greater_less => 0x9e,
-            else => unreachable,
-        },
-        .ae => return switch (tag) {
-            .cond_jmp_above_below => 0x83,
-            .cond_set_byte_above_below => 0x93,
-            else => unreachable,
-        },
-        .a => return switch (tag) {
-            .cond_jmp_above_below => 0x87,
-            .cond_set_byte_greater_less => 0x97,
-            else => unreachable,
-        },
-        .b => return switch (tag) {
-            .cond_jmp_above_below => 0x82,
-            .cond_set_byte_greater_less => 0x92,
-            else => unreachable,
-        },
-        .be => return switch (tag) {
-            .cond_jmp_above_below => 0x86,
-            .cond_set_byte_greater_less => 0x96,
-            else => unreachable,
+fn mirCondJmp(emit: *Emit, mir_tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!void {
+    const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
+    const target = emit.mir.instructions.items(.data)[inst].inst;
+    const tag = switch (mir_tag) {
+        .cond_jmp_greater_less => switch (ops.flags) {
+            0b00 => Tag.jge,
+            0b01 => Tag.jg,
+            0b10 => Tag.jl,
+            0b11 => Tag.jle,
         },
-        .eq => return switch (tag) {
-            .cond_jmp_eq_ne => 0x84,
-            .cond_set_byte_eq_ne => 0x94,
-            else => unreachable,
+        .cond_jmp_above_below => switch (ops.flags) {
+            0b00 => Tag.jae,
+            0b01 => Tag.ja,
+            0b10 => Tag.jb,
+            0b11 => Tag.jbe,
         },
-        .ne => return switch (tag) {
-            .cond_jmp_eq_ne => 0x85,
-            .cond_set_byte_eq_ne => 0x95,
-            else => unreachable,
+        .cond_jmp_eq_ne => switch (@truncate(u1, ops.flags)) {
+            0b0 => Tag.jne,
+            0b1 => Tag.je,
         },
-    }
-}
-
-fn mirCondJmp(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!void {
-    const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
-    const target = emit.mir.instructions.items(.data)[inst].inst;
-    const cond = CondType.fromTagAndFlags(tag, ops.flags);
-    const opc = getCondOpCode(tag, cond);
+        else => unreachable,
+    };
     const source = emit.code.items.len;
-    const encoder = try Encoder.init(emit.code, 6);
-    encoder.opcode_2byte(0x0f, opc);
+    try lowerToDEnc(tag, 0, emit.code);
     try emit.relocs.append(emit.bin_file.allocator, .{
         .source = source,
         .target = target,
-        .offset = emit.code.items.len,
+        .offset = emit.code.items.len - 4,
         .length = 6,
     });
-    encoder.imm32(0);
 }
 
-fn mirCondSetByte(emit: *Emit, tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!void {
+fn mirCondSetByte(emit: *Emit, mir_tag: Mir.Inst.Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
-    const cond = CondType.fromTagAndFlags(tag, ops.flags);
-    const opc = getCondOpCode(tag, cond);
-    const encoder = try Encoder.init(emit.code, 4);
-    encoder.rex(.{
-        .w = true,
-        .b = ops.reg1.isExtended(),
-    });
-    encoder.opcode_2byte(0x0f, opc);
-    encoder.modRm_direct(0x0, ops.reg1.lowId());
+    const tag = switch (mir_tag) {
+        .cond_set_byte_greater_less => switch (ops.flags) {
+            0b00 => Tag.setge,
+            0b01 => Tag.setg,
+            0b10 => Tag.setl,
+            0b11 => Tag.setle,
+        },
+        .cond_set_byte_above_below => switch (ops.flags) {
+            0b00 => Tag.setae,
+            0b01 => Tag.seta,
+            0b10 => Tag.setb,
+            0b11 => Tag.setbe,
+        },
+        .cond_set_byte_eq_ne => switch (@truncate(u1, ops.flags)) {
+            0b0 => Tag.setne,
+            0b1 => Tag.sete,
+        },
+        else => unreachable,
+    };
+    return lowerToMEnc(tag, RegisterOrMemory.reg(ops.reg1), emit.code);
 }
 
 fn mirTest(emit: *Emit, inst: Mir.Inst.Index) InnerError!void {
@@ -472,6 +383,103 @@ const Tag = enum {
     syscall,
     ret_near,
     ret_far,
+    jo,
+    jno,
+    jb,
+    jbe,
+    jc,
+    jnae,
+    jnc,
+    jae,
+    je,
+    jz,
+    jne,
+    jnz,
+    jna,
+    jnb,
+    jnbe,
+    ja,
+    js,
+    jns,
+    jpe,
+    jp,
+    jpo,
+    jnp,
+    jnge,
+    jl,
+    jge,
+    jnl,
+    jle,
+    jng,
+    jg,
+    jnle,
+    seto,
+    setno,
+    setb,
+    setc,
+    setnae,
+    setnb,
+    setnc,
+    setae,
+    sete,
+    setz,
+    setne,
+    setnz,
+    setbe,
+    setna,
+    seta,
+    setnbe,
+    sets,
+    setns,
+    setp,
+    setpe,
+    setnp,
+    setop,
+    setl,
+    setnge,
+    setnl,
+    setge,
+    setle,
+    setng,
+    setnle,
+    setg,
+
+    fn isSetCC(tag: Tag) bool {
+        return switch (tag) {
+            .seto,
+            .setno,
+            .setb,
+            .setc,
+            .setnae,
+            .setnb,
+            .setnc,
+            .setae,
+            .sete,
+            .setz,
+            .setne,
+            .setnz,
+            .setbe,
+            .setna,
+            .seta,
+            .setnbe,
+            .sets,
+            .setns,
+            .setp,
+            .setpe,
+            .setnp,
+            .setop,
+            .setl,
+            .setnge,
+            .setnl,
+            .setge,
+            .setle,
+            .setng,
+            .setnle,
+            .setg,
+            => true,
+            else => false,
+        };
+    }
 };
 
 const Encoding = enum {
@@ -547,11 +555,43 @@ inline fn getOpCode(tag: Tag, enc: Encoding, is_one_byte: bool) ?OpCode {
         .d => return switch (tag) {
             .jmp_near => OpCode.oneByte(0xe9),
             .call_near => OpCode.oneByte(0xe8),
+            .jo => if (is_one_byte) OpCode.oneByte(0x70) else OpCode.twoByte(0x0f, 0x80),
+            .jno => if (is_one_byte) OpCode.oneByte(0x71) else OpCode.twoByte(0x0f, 0x81),
+            .jb, .jc, .jnae => if (is_one_byte) OpCode.oneByte(0x72) else OpCode.twoByte(0x0f, 0x82),
+            .jnb, .jnc, .jae => if (is_one_byte) OpCode.oneByte(0x73) else OpCode.twoByte(0x0f, 0x83),
+            .je, .jz => if (is_one_byte) OpCode.oneByte(0x74) else OpCode.twoByte(0x0f, 0x84),
+            .jne, .jnz => if (is_one_byte) OpCode.oneByte(0x75) else OpCode.twoByte(0x0f, 0x85),
+            .jna, .jbe => if (is_one_byte) OpCode.oneByte(0x76) else OpCode.twoByte(0x0f, 0x86),
+            .jnbe, .ja => if (is_one_byte) OpCode.oneByte(0x77) else OpCode.twoByte(0x0f, 0x87),
+            .js => if (is_one_byte) OpCode.oneByte(0x78) else OpCode.twoByte(0x0f, 0x88),
+            .jns => if (is_one_byte) OpCode.oneByte(0x79) else OpCode.twoByte(0x0f, 0x89),
+            .jpe, .jp => if (is_one_byte) OpCode.oneByte(0x7a) else OpCode.twoByte(0x0f, 0x8a),
+            .jpo, .jnp => if (is_one_byte) OpCode.oneByte(0x7b) else OpCode.twoByte(0x0f, 0x8b),
+            .jnge, .jl => if (is_one_byte) OpCode.oneByte(0x7c) else OpCode.twoByte(0x0f, 0x8c),
+            .jge, .jnl => if (is_one_byte) OpCode.oneByte(0x7d) else OpCode.twoByte(0x0f, 0x8d),
+            .jle, .jng => if (is_one_byte) OpCode.oneByte(0x7e) else OpCode.twoByte(0x0f, 0x8e),
+            .jg, .jnle => if (is_one_byte) OpCode.oneByte(0x7f) else OpCode.twoByte(0x0f, 0x8f),
             else => null,
         },
         .m => return switch (tag) {
             .jmp_near, .call_near, .push => OpCode.oneByte(0xff),
             .pop => OpCode.oneByte(0x8f),
+            .seto => OpCode.twoByte(0x0f, 0x90),
+            .setno => OpCode.twoByte(0x0f, 0x91),
+            .setb, .setc, .setnae => OpCode.twoByte(0x0f, 0x92),
+            .setnb, .setnc, .setae => OpCode.twoByte(0x0f, 0x93),
+            .sete, .setz => OpCode.twoByte(0x0f, 0x94),
+            .setne, .setnz => OpCode.twoByte(0x0f, 0x95),
+            .setbe, .setna => OpCode.twoByte(0x0f, 0x96),
+            .seta, .setnbe => OpCode.twoByte(0x0f, 0x97),
+            .sets => OpCode.twoByte(0x0f, 0x98),
+            .setns => OpCode.twoByte(0x0f, 0x99),
+            .setp, .setpe => OpCode.twoByte(0x0f, 0x9a),
+            .setnp, .setop => OpCode.twoByte(0x0f, 0x9b),
+            .setl, .setnge => OpCode.twoByte(0x0f, 0x9c),
+            .setnl, .setge => OpCode.twoByte(0x0f, 0x9d),
+            .setle, .setng => OpCode.twoByte(0x0f, 0x9e),
+            .setnle, .setg => OpCode.twoByte(0x0f, 0x9f),
             else => null,
         },
         .o => return switch (tag) {
@@ -628,6 +668,37 @@ inline fn getModRmExt(tag: Tag) ?u3 {
         .push => 0x6,
         .pop => 0x0,
         .@"test" => 0x0,
+        .seto,
+        .setno,
+        .setb,
+        .setc,
+        .setnae,
+        .setnb,
+        .setnc,
+        .setae,
+        .sete,
+        .setz,
+        .setne,
+        .setnz,
+        .setbe,
+        .setna,
+        .seta,
+        .setnbe,
+        .sets,
+        .setns,
+        .setp,
+        .setpe,
+        .setnp,
+        .setop,
+        .setl,
+        .setnge,
+        .setnl,
+        .setge,
+        .setle,
+        .setng,
+        .setnle,
+        .setg,
+        => 0x0,
         else => null,
     };
 }
@@ -718,7 +789,7 @@ fn lowerToOEnc(tag: Tag, reg: Register, code: *std.ArrayList(u8)) InnerError!voi
 
 fn lowerToDEnc(tag: Tag, imm: i32, code: *std.ArrayList(u8)) InnerError!void {
     const opc = getOpCode(tag, .d, false).?;
-    const encoder = try Encoder.init(code, 5);
+    const encoder = try Encoder.init(code, 6);
     opc.encode(encoder);
     encoder.imm32(imm);
 }
@@ -728,10 +799,13 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
     const modrm_ext = getModRmExt(tag).?;
     switch (reg_or_mem) {
         .register => |reg| {
-            if (reg.size() != 64) return error.EmitFail;
+            // TODO clean this up!
+            if (reg.size() != 64) {
+                if (reg.size() != 8 and !tag.isSetCC()) return error.EmitFail;
+            }
             const encoder = try Encoder.init(code, 3);
             encoder.rex(.{
-                .w = false,
+                .w = tag.isSetCC(),
                 .b = reg.isExtended(),
             });
             opc.encode(encoder);
@@ -740,9 +814,12 @@ fn lowerToMEnc(tag: Tag, reg_or_mem: RegisterOrMemory, code: *std.ArrayList(u8))
         .memory => |mem_op| {
             const encoder = try Encoder.init(code, 8);
             if (mem_op.reg) |reg| {
-                if (reg.size() != 64) return error.EmitFail;
+                // TODO clean this up!
+                if (reg.size() != 64) {
+                    if (reg.size() != 8 and !tag.isSetCC()) return error.EmitFail;
+                }
                 encoder.rex(.{
-                    .w = false,
+                    .w = tag.isSetCC(),
                     .b = reg.isExtended(),
                 });
                 opc.encode(encoder);
@@ -1172,6 +1249,7 @@ fn immOpSize(imm: i64) u8 {
     return 64;
 }
 
+// TODO
 fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const scale = ops.flags;
@@ -1196,6 +1274,7 @@ fn mirArithScaleSrc(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     }
 }
 
+// TODO
 fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const scale = ops.flags;
@@ -1243,6 +1322,7 @@ fn mirArithScaleDst(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void
     }
 }
 
+// TODO
 fn mirArithScaleImm(emit: *Emit, tag: Tag, inst: Mir.Inst.Index) InnerError!void {
     const ops = Mir.Ops.decode(emit.mir.instructions.items(.ops)[inst]);
     const scale = ops.flags;
@@ -1757,6 +1837,8 @@ test "lower M encoding" {
     try expectEqualHexStrings("\xFF\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [rip + 0x10]");
     try lowerToMEnc(.jmp_near, RegisterOrMemory.mem(null, 0x10), code.buffer());
     try expectEqualHexStrings("\xFF\x24\x25\x10\x00\x00\x00", code.emitted(), "jmp qword ptr [ds:0x10]");
+    try lowerToMEnc(.seta, RegisterOrMemory.reg(.r11b), code.buffer());
+    try expectEqualHexStrings("\x49\x0F\x97\xC3", code.emitted(), "seta r11b");
 }
 
 test "lower O encoding" {