Commit 7a2963efab

Jacob Young <jacobly0@users.noreply.github.com>
2025-03-27 17:44:58
x86_64: add avx512 registers
1 parent bbf8abf
Changed files (3)
src/arch/x86_64/bits.zig
@@ -360,11 +360,20 @@ pub const Register = enum(u8) {
 
     ah, ch, dh, bh,
 
-    ymm0, ymm1, ymm2,  ymm3,  ymm4,  ymm5,  ymm6,  ymm7,
-    ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15,
+    zmm0,  zmm1, zmm2,  zmm3,  zmm4,  zmm5,  zmm6,  zmm7,
+    zmm8,  zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
+    zmm16, zmm17,zmm18, zmm19, zmm20, zmm21, zmm22, zmm23,
+    zmm24, zmm25,zmm26, zmm27, zmm28, zmm29, zmm30, zmm31,
 
-    xmm0, xmm1, xmm2,  xmm3,  xmm4,  xmm5,  xmm6,  xmm7,
-    xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15,
+    ymm0,  ymm1, ymm2,  ymm3,  ymm4,  ymm5,  ymm6,  ymm7,
+    ymm8,  ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15,
+    ymm16, ymm17,ymm18, ymm19, ymm20, ymm21, ymm22, ymm23,
+    ymm24, ymm25,ymm26, ymm27, ymm28, ymm29, ymm30, ymm31,
+
+    xmm0,  xmm1, xmm2,  xmm3,  xmm4,  xmm5,  xmm6,  xmm7,
+    xmm8,  xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15,
+    xmm16, xmm17,xmm18, xmm19, xmm20, xmm21, xmm22, xmm23,
+    xmm24, xmm25,xmm26, xmm27, xmm28, xmm29, xmm30, xmm31,
 
     mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7,
 
@@ -404,8 +413,9 @@ pub const Register = enum(u8) {
             @intFromEnum(Register.al)   ... @intFromEnum(Register.r15b)  => .general_purpose,
             @intFromEnum(Register.ah)   ... @intFromEnum(Register.bh)    => .gphi,
 
-            @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm15) => .sse,
-            @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm15) => .sse,
+            @intFromEnum(Register.zmm0) ... @intFromEnum(Register.zmm31) => .sse,
+            @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm31) => .sse,
+            @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm31) => .sse,
             @intFromEnum(Register.mm0)  ... @intFromEnum(Register.mm7)   => .mmx,
             @intFromEnum(Register.st0)  ... @intFromEnum(Register.st7)   => .x87,
 
@@ -428,13 +438,14 @@ pub const Register = enum(u8) {
             @intFromEnum(Register.al)   ... @intFromEnum(Register.r15b)  => @intFromEnum(Register.al),
             @intFromEnum(Register.ah)   ... @intFromEnum(Register.bh)    => @intFromEnum(Register.ah),
 
-            @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm15) => @intFromEnum(Register.ymm0) - 16,
-            @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm15) => @intFromEnum(Register.xmm0) - 16,
-            @intFromEnum(Register.mm0)  ... @intFromEnum(Register.mm7)   => @intFromEnum(Register.mm0) - 32,
-            @intFromEnum(Register.st0)  ... @intFromEnum(Register.st7)   => @intFromEnum(Register.st0) - 40,
-            @intFromEnum(Register.es)   ... @intFromEnum(Register.gs)    => @intFromEnum(Register.es) - 48,
-            @intFromEnum(Register.cr0)  ... @intFromEnum(Register.cr15)  => @intFromEnum(Register.cr0) - 54,
-            @intFromEnum(Register.dr0)  ... @intFromEnum(Register.dr15)  => @intFromEnum(Register.dr0) - 70,
+            @intFromEnum(Register.zmm0) ... @intFromEnum(Register.zmm31) => @intFromEnum(Register.zmm0) - 16,
+            @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm31) => @intFromEnum(Register.ymm0) - 16,
+            @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm31) => @intFromEnum(Register.xmm0) - 16,
+            @intFromEnum(Register.mm0)  ... @intFromEnum(Register.mm7)   => @intFromEnum(Register.mm0)  - 48,
+            @intFromEnum(Register.st0)  ... @intFromEnum(Register.st7)   => @intFromEnum(Register.st0)  - 56,
+            @intFromEnum(Register.es)   ... @intFromEnum(Register.gs)    => @intFromEnum(Register.es)   - 64,
+            @intFromEnum(Register.cr0)  ... @intFromEnum(Register.cr15)  => @intFromEnum(Register.cr0)  - 70,
+            @intFromEnum(Register.dr0)  ... @intFromEnum(Register.dr15)  => @intFromEnum(Register.dr0)  - 86,
 
             else => unreachable,
             // zig fmt: on
@@ -451,6 +462,7 @@ pub const Register = enum(u8) {
             @intFromEnum(Register.al)   ... @intFromEnum(Register.r15b)  => 8,
             @intFromEnum(Register.ah)   ... @intFromEnum(Register.bh)    => 8,
 
+            @intFromEnum(Register.zmm0) ... @intFromEnum(Register.zmm15) => 512,
             @intFromEnum(Register.ymm0) ... @intFromEnum(Register.ymm15) => 256,
             @intFromEnum(Register.xmm0) ... @intFromEnum(Register.xmm15) => 128,
             @intFromEnum(Register.mm0)  ... @intFromEnum(Register.mm7)   => 64,
@@ -474,8 +486,9 @@ pub const Register = enum(u8) {
             @intFromEnum(Register.r8w) ... @intFromEnum(Register.r15w)   => true,
             @intFromEnum(Register.r8b) ... @intFromEnum(Register.r15b)   => true,
 
-            @intFromEnum(Register.ymm8) ... @intFromEnum(Register.ymm15) => true,
-            @intFromEnum(Register.xmm8) ... @intFromEnum(Register.xmm15) => true,
+            @intFromEnum(Register.zmm8) ... @intFromEnum(Register.zmm31) => true,
+            @intFromEnum(Register.ymm8) ... @intFromEnum(Register.ymm31) => true,
+            @intFromEnum(Register.xmm8) ... @intFromEnum(Register.xmm31) => true,
 
             @intFromEnum(Register.cr8)  ... @intFromEnum(Register.cr15)  => true,
             @intFromEnum(Register.dr8)  ... @intFromEnum(Register.dr15)  => true,
@@ -485,7 +498,7 @@ pub const Register = enum(u8) {
         };
     }
 
-    pub fn enc(reg: Register) u4 {
+    pub fn enc(reg: Register) u5 {
         const base = switch (@intFromEnum(reg)) {
             // zig fmt: off
             @intFromEnum(Register.rax)  ... @intFromEnum(Register.r15)   => @intFromEnum(Register.rax),
@@ -510,10 +523,6 @@ pub const Register = enum(u8) {
         return @truncate(@intFromEnum(reg) - base);
     }
 
-    pub fn lowEnc(reg: Register) u3 {
-        return @truncate(reg.enc());
-    }
-
     pub fn toBitSize(reg: Register, bit_size: u64) Register {
         return switch (bit_size) {
             8 => reg.to8(),
@@ -558,11 +567,12 @@ pub const Register = enum(u8) {
         };
     }
 
-    fn sseBase(reg: Register) u7 {
+    fn sseBase(reg: Register) u8 {
         assert(reg.class() == .sse);
         return switch (@intFromEnum(reg)) {
-            @intFromEnum(Register.ymm0)...@intFromEnum(Register.ymm15) => @intFromEnum(Register.ymm0),
-            @intFromEnum(Register.xmm0)...@intFromEnum(Register.xmm15) => @intFromEnum(Register.xmm0),
+            @intFromEnum(Register.zmm0)...@intFromEnum(Register.zmm31) => @intFromEnum(Register.zmm0),
+            @intFromEnum(Register.ymm0)...@intFromEnum(Register.ymm31) => @intFromEnum(Register.ymm0),
+            @intFromEnum(Register.xmm0)...@intFromEnum(Register.xmm31) => @intFromEnum(Register.xmm0),
             else => unreachable,
         };
     }
@@ -682,13 +692,6 @@ pub const Memory = struct {
         rip_inst: Mir.Inst.Index,
 
         pub const Tag = @typeInfo(Base).@"union".tag_type.?;
-
-        pub fn isExtended(self: Base) bool {
-            return switch (self) {
-                .none, .frame, .table, .reloc, .rip_inst => false, // rsp, rbp, and rip are not extended
-                .reg => |reg| reg.isExtended(),
-            };
-        }
     };
 
     pub const Mod = union(enum(u1)) {
src/arch/x86_64/encoder.zig
@@ -206,19 +206,22 @@ pub const Instruction = struct {
             };
         }
 
-        pub fn isBaseExtended(op: Operand) bool {
+        pub fn baseExtEnc(op: Operand) u2 {
             return switch (op) {
-                .none, .imm => false,
-                .reg => |reg| reg.isExtended(),
-                .mem => |mem| mem.base().isExtended(),
+                .none, .imm => 0b00,
+                .reg => |reg| @truncate(reg.enc() >> 3),
+                .mem => |mem| switch (mem.base()) {
+                    .none, .frame, .table, .reloc, .rip_inst => 0b00, // rsp, rbp, and rip are not extended
+                    .reg => |reg| @truncate(reg.enc() >> 3),
+                },
                 .bytes => unreachable,
             };
         }
 
-        pub fn isIndexExtended(op: Operand) bool {
+        pub fn indexExtEnc(op: Operand) u2 {
             return switch (op) {
-                .none, .reg, .imm => false,
-                .mem => |mem| if (mem.scaleIndex()) |si| si.index.isExtended() else false,
+                .none, .reg, .imm => 0b00,
+                .mem => |mem| if (mem.scaleIndex()) |si| @truncate(si.index.enc() >> 3) else 0b00,
                 .bytes => unreachable,
             };
         }
@@ -422,14 +425,14 @@ pub const Instruction = struct {
                 };
                 switch (mem_op) {
                     .reg => |reg| {
-                        const rm = switch (data.op_en) {
+                        const rm: u3 = switch (data.op_en) {
                             .ia, .m, .mi, .m1, .mc, .vm, .vmi => enc.modRmExt(),
-                            .mr, .mri, .mrc => inst.ops[1].reg.lowEnc(),
-                            .rm, .rmi, .rm0, .rvm, .rvmr, .rvmi, .rmv => inst.ops[0].reg.lowEnc(),
-                            .mvr => inst.ops[2].reg.lowEnc(),
+                            .mr, .mri, .mrc => @truncate(inst.ops[1].reg.enc()),
+                            .rm, .rmi, .rm0, .rvm, .rvmr, .rvmi, .rmv => @truncate(inst.ops[0].reg.enc()),
+                            .mvr => @truncate(inst.ops[2].reg.enc()),
                             else => unreachable,
                         };
-                        try encoder.modRm_direct(rm, reg.lowEnc());
+                        try encoder.modRm_direct(rm, @truncate(reg.enc()));
                     },
                     .mem => |mem| {
                         const op = switch (data.op_en) {
@@ -448,7 +451,7 @@ pub const Instruction = struct {
                     .ia => try encodeImm(inst.ops[0].imm, data.ops[0], encoder),
                     .mi => try encodeImm(inst.ops[1].imm, data.ops[1], encoder),
                     .rmi, .mri, .vmi => try encodeImm(inst.ops[2].imm, data.ops[2], encoder),
-                    .rvmr => try encoder.imm8(@as(u8, inst.ops[3].reg.enc()) << 4),
+                    .rvmr => try encoder.imm8(@as(u8, @as(u4, @intCast(inst.ops[3].reg.enc()))) << 4),
                     .rvmi => try encodeImm(inst.ops[3].imm, data.ops[3], encoder),
                     else => {},
                 }
@@ -462,8 +465,8 @@ pub const Instruction = struct {
         const final = opcode.len - 1;
         for (opcode[first..final]) |byte| try encoder.opcode_1byte(byte);
         switch (inst.encoding.data.op_en) {
-            .o, .oz, .oi => try encoder.opcode_withReg(opcode[final], inst.ops[0].reg.lowEnc()),
-            .zo => try encoder.opcode_withReg(opcode[final], inst.ops[1].reg.lowEnc()),
+            .o, .oz, .oi => try encoder.opcode_withReg(opcode[final], @truncate(inst.ops[0].reg.enc())),
+            .zo => try encoder.opcode_withReg(opcode[final], @truncate(inst.ops[1].reg.enc())),
             else => try encoder.opcode_1byte(opcode[final]),
         }
     }
@@ -533,23 +536,29 @@ pub const Instruction = struct {
 
         switch (op_en) {
             .z, .i, .zi, .ii, .ia, .fd, .td, .d => {},
-            .o, .oz, .oi => rex.b = inst.ops[0].reg.isExtended(),
-            .zo => rex.b = inst.ops[1].reg.isExtended(),
+            .o, .oz, .oi => rex.b = inst.ops[0].reg.enc() & 0b01000 != 0,
+            .zo => rex.b = inst.ops[1].reg.enc() & 0b01000 != 0,
             .m, .mi, .m1, .mc, .mr, .rm, .rmi, .mri, .mrc, .rm0, .rmv => {
                 const r_op = switch (op_en) {
                     .rm, .rmi, .rm0, .rmv => inst.ops[0],
                     .mr, .mri, .mrc => inst.ops[1],
                     else => .none,
                 };
-                rex.r = r_op.isBaseExtended();
+                const r_op_base_ext_enc = r_op.baseExtEnc();
+                rex.r = r_op_base_ext_enc & 0b01 != 0;
+                assert(r_op_base_ext_enc & 0b10 == 0);
 
                 const b_x_op = switch (op_en) {
                     .rm, .rmi, .rm0 => inst.ops[1],
                     .m, .mi, .m1, .mc, .mr, .mri, .mrc => inst.ops[0],
                     else => unreachable,
                 };
-                rex.b = b_x_op.isBaseExtended();
-                rex.x = b_x_op.isIndexExtended();
+                const b_x_op_base_ext_enc = b_x_op.baseExtEnc();
+                rex.b = b_x_op_base_ext_enc & 0b01 != 0;
+                assert(b_x_op_base_ext_enc & 0b10 == 0);
+                const b_x_op_index_ext_enc = b_x_op.indexExtEnc();
+                rex.x = b_x_op_index_ext_enc & 0b01 != 0;
+                assert(b_x_op_index_ext_enc & 0b10 == 0);
             },
             .vm, .vmi, .rvm, .rvmr, .rvmi, .mvr => unreachable,
         }
@@ -576,7 +585,9 @@ pub const Instruction = struct {
                     .m, .mi, .m1, .mc, .vm, .vmi => .none,
                     else => unreachable,
                 };
-                vex.r = r_op.isBaseExtended();
+                const r_op_base_ext_enc = r_op.baseExtEnc();
+                vex.r = r_op_base_ext_enc & 0b01 != 0;
+                assert(r_op_base_ext_enc & 0b10 == 0);
 
                 const b_x_op = switch (op_en) {
                     .rm, .rmi, .rm0, .vm, .vmi, .rmv => inst.ops[1],
@@ -584,8 +595,12 @@ pub const Instruction = struct {
                     .rvm, .rvmr, .rvmi => inst.ops[2],
                     else => unreachable,
                 };
-                vex.b = b_x_op.isBaseExtended();
-                vex.x = b_x_op.isIndexExtended();
+                const b_x_op_base_ext_enc = b_x_op.baseExtEnc();
+                vex.b = b_x_op_base_ext_enc & 0b01 != 0;
+                assert(b_x_op_base_ext_enc & 0b10 == 0);
+                const b_x_op_index_ext_enc = b_x_op.indexExtEnc();
+                vex.x = b_x_op_index_ext_enc & 0b01 != 0;
+                assert(b_x_op_index_ext_enc & 0b10 == 0);
             },
         }
 
@@ -622,8 +637,8 @@ pub const Instruction = struct {
     }
 
     fn encodeMemory(encoding: Encoding, mem: Memory, operand: Operand, encoder: anytype) !void {
-        const operand_enc = switch (operand) {
-            .reg => |reg| reg.lowEnc(),
+        const operand_enc: u3 = switch (operand) {
+            .reg => |reg| @truncate(reg.enc()),
             .none => encoding.modRmExt(),
             else => unreachable,
         };
@@ -635,7 +650,7 @@ pub const Instruction = struct {
                     try encoder.modRm_SIBDisp0(operand_enc);
                     if (mem.scaleIndex()) |si| {
                         const scale = math.log2_int(u4, si.scale);
-                        try encoder.sib_scaleIndexDisp32(scale, si.index.lowEnc());
+                        try encoder.sib_scaleIndexDisp32(scale, @truncate(si.index.enc()));
                     } else {
                         try encoder.sib_disp32();
                     }
@@ -647,21 +662,21 @@ pub const Instruction = struct {
                         try encoder.modRm_SIBDisp0(operand_enc);
                         if (mem.scaleIndex()) |si| {
                             const scale = math.log2_int(u4, si.scale);
-                            try encoder.sib_scaleIndexDisp32(scale, si.index.lowEnc());
+                            try encoder.sib_scaleIndexDisp32(scale, @truncate(si.index.enc()));
                         } else {
                             try encoder.sib_disp32();
                         }
                         try encoder.disp32(sib.disp);
                     },
                     .general_purpose => {
-                        const dst = base.lowEnc();
+                        const dst: u3 = @truncate(base.enc());
                         const src = operand_enc;
                         if (dst == 4 or mem.scaleIndex() != null) {
                             if (sib.disp == 0 and dst != 5) {
                                 try encoder.modRm_SIBDisp0(src);
                                 if (mem.scaleIndex()) |si| {
                                     const scale = math.log2_int(u4, si.scale);
-                                    try encoder.sib_scaleIndexBase(scale, si.index.lowEnc(), dst);
+                                    try encoder.sib_scaleIndexBase(scale, @truncate(si.index.enc()), dst);
                                 } else {
                                     try encoder.sib_base(dst);
                                 }
@@ -669,7 +684,7 @@ pub const Instruction = struct {
                                 try encoder.modRm_SIBDisp8(src);
                                 if (mem.scaleIndex()) |si| {
                                     const scale = math.log2_int(u4, si.scale);
-                                    try encoder.sib_scaleIndexBaseDisp8(scale, si.index.lowEnc(), dst);
+                                    try encoder.sib_scaleIndexBaseDisp8(scale, @truncate(si.index.enc()), dst);
                                 } else {
                                     try encoder.sib_baseDisp8(dst);
                                 }
@@ -678,7 +693,7 @@ pub const Instruction = struct {
                                 try encoder.modRm_SIBDisp32(src);
                                 if (mem.scaleIndex()) |si| {
                                     const scale = math.log2_int(u4, si.scale);
-                                    try encoder.sib_scaleIndexBaseDisp32(scale, si.index.lowEnc(), dst);
+                                    try encoder.sib_scaleIndexBaseDisp32(scale, @truncate(si.index.enc()), dst);
                                 } else {
                                     try encoder.sib_baseDisp32(dst);
                                 }
@@ -867,7 +882,7 @@ fn Encoder(comptime T: type, comptime opts: Options) type {
 
                 try self.writer.writeByte(
                     @as(u8, @intFromBool(fields.w)) << 7 |
-                        @as(u8, ~fields.v.enc()) << 3 |
+                        @as(u8, ~@as(u4, @intCast(fields.v.enc()))) << 3 |
                         @as(u8, @intFromBool(fields.l)) << 2 |
                         @as(u8, @intFromEnum(fields.p)) << 0,
                 );
@@ -875,7 +890,7 @@ fn Encoder(comptime T: type, comptime opts: Options) type {
                 try self.writer.writeByte(0b1100_0101);
                 try self.writer.writeByte(
                     @as(u8, ~@intFromBool(fields.r)) << 7 |
-                        @as(u8, ~fields.v.enc()) << 3 |
+                        @as(u8, ~@as(u4, @intCast(fields.v.enc()))) << 3 |
                         @as(u8, @intFromBool(fields.l)) << 2 |
                         @as(u8, @intFromEnum(fields.p)) << 0,
                 );
src/arch/x86_64/Encoding.zig
@@ -50,7 +50,7 @@ pub fn findByMnemonic(
         else => {},
     } else false;
     const rex_extended = for (ops) |op| {
-        if (op.isBaseExtended() or op.isIndexExtended()) break true;
+        if (op.baseExtEnc() != 0b00 or op.indexExtEnc() != 0b00) break true;
     } else false;
 
     if ((rex_required or rex_extended) and rex_invalid) return error.CannotEncode;