Commit 3a5e3c52e0

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-07 02:31:48
x86_64: implement `@mulAdd`
1 parent 0bd92da
Changed files (7)
src/arch/x86_64/bits.zig
@@ -485,7 +485,9 @@ pub const Memory = union(enum) {
         dword,
         qword,
         tbyte,
-        dqword,
+        xword,
+        yword,
+        zword,
 
         pub fn fromSize(size: u32) PtrSize {
             return switch (size) {
@@ -493,7 +495,9 @@ pub const Memory = union(enum) {
                 2...2 => .word,
                 3...4 => .dword,
                 5...8 => .qword,
-                9...16 => .dqword,
+                9...16 => .xword,
+                17...32 => .yword,
+                33...64 => .zword,
                 else => unreachable,
             };
         }
@@ -505,7 +509,9 @@ pub const Memory = union(enum) {
                 32 => .dword,
                 64 => .qword,
                 80 => .tbyte,
-                128 => .dqword,
+                128 => .xword,
+                256 => .yword,
+                512 => .zword,
                 else => unreachable,
             };
         }
@@ -517,7 +523,9 @@ pub const Memory = union(enum) {
                 .dword => 32,
                 .qword => 64,
                 .tbyte => 80,
-                .dqword => 128,
+                .xword => 128,
+                .yword => 256,
+                .zword => 512,
             };
         }
     };
src/arch/x86_64/CodeGen.zig
@@ -1200,6 +1200,32 @@ fn asmRegisterRegisterImmediate(
     });
 }
 
+fn asmRegisterRegisterMemory(
+    self: *Self,
+    tag: Mir.Inst.Tag,
+    reg1: Register,
+    reg2: Register,
+    m: Memory,
+) !void {
+    _ = try self.addInst(.{
+        .tag = tag,
+        .ops = switch (m) {
+            .sib => .rrm_sib,
+            .rip => .rrm_rip,
+            else => unreachable,
+        },
+        .data = .{ .rrx = .{
+            .r1 = reg1,
+            .r2 = reg2,
+            .payload = switch (m) {
+                .sib => try self.addExtra(Mir.MemorySib.encode(m)),
+                .rip => try self.addExtra(Mir.MemoryRip.encode(m)),
+                else => unreachable,
+            },
+        } },
+    });
+}
+
 fn asmMemory(self: *Self, tag: Mir.Inst.Tag, m: Memory) !void {
     _ = try self.addInst(.{
         .tag = tag,
@@ -9369,9 +9395,146 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void {
 fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
     const pl_op = self.air.instructions.items(.data)[inst].pl_op;
     const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
-    _ = extra;
-    return self.fail("TODO implement airMulAdd for x86_64", .{});
-    //return self.finishAir(inst, result, .{ extra.lhs, extra.rhs, pl_op.operand });
+    const ty = self.air.typeOfIndex(inst);
+
+    if (!self.hasFeature(.fma)) return self.fail("TODO implement airMulAdd for {}", .{
+        ty.fmt(self.bin_file.options.module.?),
+    });
+
+    const ops = [3]Air.Inst.Ref{ extra.lhs, extra.rhs, pl_op.operand };
+    var mcvs: [3]MCValue = undefined;
+    var locks = [1]?RegisterManager.RegisterLock{null} ** 3;
+    defer for (locks) |reg_lock| if (reg_lock) |lock| self.register_manager.unlockReg(lock);
+    var order = [1]u2{0} ** 3;
+    var unused = std.StaticBitSet(3).initFull();
+    for (ops, &mcvs, &locks, 0..) |op, *mcv, *lock, op_i| {
+        const op_index = @intCast(u2, op_i);
+        mcv.* = try self.resolveInst(op);
+        if (unused.isSet(0) and mcv.isRegister() and self.reuseOperand(inst, op, op_index, mcv.*)) {
+            order[op_index] = 1;
+            unused.unset(0);
+        } else if (unused.isSet(2) and mcv.isMemory()) {
+            order[op_index] = 3;
+            unused.unset(2);
+        }
+        switch (mcv.*) {
+            .register => |reg| lock.* = self.register_manager.lockReg(reg),
+            else => {},
+        }
+    }
+    for (&order, &mcvs, &locks) |*mop_index, *mcv, *lock| {
+        if (mop_index.* != 0) continue;
+        mop_index.* = 1 + @intCast(u2, unused.toggleFirstSet().?);
+        if (mop_index.* > 1 and mcv.isRegister()) continue;
+        const reg = try self.copyToTmpRegister(ty, mcv.*);
+        mcv.* = .{ .register = reg };
+        if (lock.*) |old_lock| self.register_manager.unlockReg(old_lock);
+        lock.* = self.register_manager.lockRegAssumeUnused(reg);
+    }
+
+    const tag: ?Mir.Inst.Tag =
+        if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or mem.eql(u2, &order, &.{ 3, 1, 2 }))
+        switch (ty.zigTypeTag()) {
+            .Float => switch (ty.floatBits(self.target.*)) {
+                32 => .vfmadd132ss,
+                64 => .vfmadd132sd,
+                else => null,
+            },
+            .Vector => switch (ty.childType().zigTypeTag()) {
+                .Float => switch (ty.childType().floatBits(self.target.*)) {
+                    32 => switch (ty.vectorLen()) {
+                        1 => .vfmadd132ss,
+                        2...8 => .vfmadd132ps,
+                        else => null,
+                    },
+                    64 => switch (ty.vectorLen()) {
+                        1 => .vfmadd132sd,
+                        2...4 => .vfmadd132pd,
+                        else => null,
+                    },
+                    else => null,
+                },
+                else => null,
+            },
+            else => unreachable,
+        }
+    else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
+        switch (ty.zigTypeTag()) {
+            .Float => switch (ty.floatBits(self.target.*)) {
+                32 => .vfmadd213ss,
+                64 => .vfmadd213sd,
+                else => null,
+            },
+            .Vector => switch (ty.childType().zigTypeTag()) {
+                .Float => switch (ty.childType().floatBits(self.target.*)) {
+                    32 => switch (ty.vectorLen()) {
+                        1 => .vfmadd213ss,
+                        2...8 => .vfmadd213ps,
+                        else => null,
+                    },
+                    64 => switch (ty.vectorLen()) {
+                        1 => .vfmadd213sd,
+                        2...4 => .vfmadd213pd,
+                        else => null,
+                    },
+                    else => null,
+                },
+                else => null,
+            },
+            else => unreachable,
+        }
+    else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
+        switch (ty.zigTypeTag()) {
+            .Float => switch (ty.floatBits(self.target.*)) {
+                32 => .vfmadd231ss,
+                64 => .vfmadd231sd,
+                else => null,
+            },
+            .Vector => switch (ty.childType().zigTypeTag()) {
+                .Float => switch (ty.childType().floatBits(self.target.*)) {
+                    32 => switch (ty.vectorLen()) {
+                        1 => .vfmadd231ss,
+                        2...8 => .vfmadd231ps,
+                        else => null,
+                    },
+                    64 => switch (ty.vectorLen()) {
+                        1 => .vfmadd231sd,
+                        2...4 => .vfmadd231pd,
+                        else => null,
+                    },
+                    else => null,
+                },
+                else => null,
+            },
+            else => null,
+        }
+    else
+        unreachable;
+    if (tag == null) return self.fail("TODO implement airMulAdd for {}", .{
+        ty.fmt(self.bin_file.options.module.?),
+    });
+
+    var mops: [3]MCValue = undefined;
+    for (order, mcvs) |mop_index, mcv| mops[mop_index - 1] = mcv;
+
+    const abi_size = @intCast(u32, ty.abiSize(self.target.*));
+    const mop1_reg = registerAlias(mops[0].getReg().?, abi_size);
+    const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
+    if (mops[2].isRegister())
+        try self.asmRegisterRegisterRegister(
+            tag.?,
+            mop1_reg,
+            mop2_reg,
+            registerAlias(mops[2].getReg().?, abi_size),
+        )
+    else
+        try self.asmRegisterRegisterMemory(
+            tag.?,
+            mop1_reg,
+            mop2_reg,
+            mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
+        );
+    return self.finishAir(inst, mops[0], ops);
 }
 
 fn resolveInst(self: *Self, ref: Air.Inst.Ref) InnerError!MCValue {
src/arch/x86_64/Encoding.zig
@@ -340,6 +340,11 @@ pub const Mnemonic = enum {
     vpunpcklbw, vpunpckldq, vpunpcklqdq, vpunpcklwd,
     // F16C
     vcvtph2ps, vcvtps2ph,
+    // FMA
+    vfmadd132pd, vfmadd213pd, vfmadd231pd,
+    vfmadd132ps, vfmadd213ps, vfmadd231ps,
+    vfmadd132sd, vfmadd213sd, vfmadd231sd,
+    vfmadd132ss, vfmadd213ss, vfmadd231ss,
     // zig fmt: on
 };
 
@@ -368,12 +373,13 @@ pub const Op = enum {
     r8, r16, r32, r64,
     rm8, rm16, rm32, rm64,
     r32_m16, r64_m16,
-    m8, m16, m32, m64, m80, m128,
+    m8, m16, m32, m64, m80, m128, m256,
     rel8, rel16, rel32,
     m,
     moffs,
     sreg,
     xmm, xmm_m32, xmm_m64, xmm_m128,
+    ymm, ymm_m256,
     // zig fmt: on
 
     pub fn fromOperand(operand: Instruction.Operand) Op {
@@ -385,6 +391,7 @@ pub const Op = enum {
                     .segment => return .sreg,
                     .floating_point => return switch (reg.bitSize()) {
                         128 => .xmm,
+                        256 => .ymm,
                         else => unreachable,
                     },
                     .general_purpose => {
@@ -418,6 +425,7 @@ pub const Op = enum {
                         64 => .m64,
                         80 => .m80,
                         128 => .m128,
+                        256 => .m256,
                         else => unreachable,
                     };
                 },
@@ -454,7 +462,8 @@ pub const Op = enum {
             .eax, .r32, .rm32, .r32_m16 => unreachable,
             .rax, .r64, .rm64, .r64_m16 => unreachable,
             .xmm, .xmm_m32, .xmm_m64, .xmm_m128 => unreachable,
-            .m8, .m16, .m32, .m64, .m80, .m128 => unreachable,
+            .ymm, .ymm_m256 => unreachable,
+            .m8, .m16, .m32, .m64, .m80, .m128, .m256 => unreachable,
             .unity => 1,
             .imm8, .imm8s, .rel8 => 8,
             .imm16, .imm16s, .rel16 => 16,
@@ -468,12 +477,13 @@ pub const Op = enum {
             .none, .o16, .o32, .o64, .moffs, .m, .sreg => unreachable,
             .unity, .imm8, .imm8s, .imm16, .imm16s, .imm32, .imm32s, .imm64 => unreachable,
             .rel8, .rel16, .rel32 => unreachable,
-            .m8, .m16, .m32, .m64, .m80, .m128 => unreachable,
+            .m8, .m16, .m32, .m64, .m80, .m128, .m256 => unreachable,
             .al, .cl, .r8, .rm8 => 8,
             .ax, .r16, .rm16 => 16,
             .eax, .r32, .rm32, .r32_m16 => 32,
             .rax, .r64, .rm64, .r64_m16 => 64,
             .xmm, .xmm_m32, .xmm_m64, .xmm_m128 => 128,
+            .ymm, .ymm_m256 => 256,
         };
     }
 
@@ -482,13 +492,14 @@ pub const Op = enum {
             .none, .o16, .o32, .o64, .moffs, .m, .sreg => unreachable,
             .unity, .imm8, .imm8s, .imm16, .imm16s, .imm32, .imm32s, .imm64 => unreachable,
             .rel8, .rel16, .rel32 => unreachable,
-            .al, .cl, .r8, .ax, .r16, .eax, .r32, .rax, .r64, .xmm => unreachable,
+            .al, .cl, .r8, .ax, .r16, .eax, .r32, .rax, .r64, .xmm, .ymm => unreachable,
             .m8, .rm8 => 8,
             .m16, .rm16, .r32_m16, .r64_m16 => 16,
             .m32, .rm32, .xmm_m32 => 32,
             .m64, .rm64, .xmm_m64 => 64,
             .m80 => 80,
             .m128, .xmm_m128 => 128,
+            .m256, .ymm_m256 => 256,
         };
     }
 
@@ -513,6 +524,7 @@ pub const Op = enum {
             .rm8, .rm16, .rm32, .rm64,
             .r32_m16, .r64_m16,
             .xmm, .xmm_m32, .xmm_m64, .xmm_m128,
+            .ymm, .ymm_m256,
             => true,
             else => false,
         };
@@ -539,7 +551,7 @@ pub const Op = enum {
             .r32_m16, .r64_m16,
             .m8, .m16, .m32, .m64, .m80, .m128,
             .m,
-            .xmm_m32, .xmm_m64, .xmm_m128,
+            .xmm_m32, .xmm_m64, .xmm_m128, .ymm_m256,
             =>  true,
             else => false,
         };
@@ -562,6 +574,7 @@ pub const Op = enum {
             .r32_m16, .r64_m16 => .general_purpose,
             .sreg => .segment,
             .xmm, .xmm_m32, .xmm_m64, .xmm_m128 => .floating_point,
+            .ymm, .ymm_m256 => .floating_point,
         };
     }
 
@@ -625,6 +638,7 @@ pub const Feature = enum {
     none,
     avx,
     f16c,
+    fma,
     sse,
     sse2,
     sse3,
src/arch/x86_64/encodings.zig
@@ -1016,5 +1016,28 @@ pub const table = [_]Entry{
     .{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128, .f16c },
 
     .{ .vcvtps2ph, .mri, &.{ .xmm_m64, .xmm, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x1d }, 0, .vex_128, .f16c },
+
+    // FMA
+    .{ .vfmadd132pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_128_long, .fma },
+    .{ .vfmadd132pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_256_long, .fma },
+    .{ .vfmadd213pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_128_long, .fma },
+    .{ .vfmadd213pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_256_long, .fma },
+    .{ .vfmadd231pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_128_long, .fma },
+    .{ .vfmadd231pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_256_long, .fma },
+
+    .{ .vfmadd132ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_128, .fma },
+    .{ .vfmadd132ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_256, .fma },
+    .{ .vfmadd213ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_128, .fma },
+    .{ .vfmadd213ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_256, .fma },
+    .{ .vfmadd231ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_128, .fma },
+    .{ .vfmadd231ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_256, .fma },
+
+    .{ .vfmadd132sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x99 }, 0, .vex_128_long, .fma },
+    .{ .vfmadd213sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0xa9 }, 0, .vex_128_long, .fma },
+    .{ .vfmadd231sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0xb9 }, 0, .vex_128_long, .fma },
+
+    .{ .vfmadd132ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0x99 }, 0, .vex_128, .fma },
+    .{ .vfmadd213ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0xa9 }, 0, .vex_128, .fma },
+    .{ .vfmadd231ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0xb9 }, 0, .vex_128, .fma },
 };
 // zig fmt: on
src/arch/x86_64/Lower.zig
@@ -205,6 +205,19 @@ pub fn lowerMir(lower: *Lower, index: Mir.Inst.Index) Error!struct {
 
         .vcvtph2ps,
         .vcvtps2ph,
+
+        .vfmadd132pd,
+        .vfmadd213pd,
+        .vfmadd231pd,
+        .vfmadd132ps,
+        .vfmadd213ps,
+        .vfmadd231ps,
+        .vfmadd132sd,
+        .vfmadd213sd,
+        .vfmadd231sd,
+        .vfmadd132ss,
+        .vfmadd213ss,
+        .vfmadd231ss,
         => try lower.mirGeneric(inst),
 
         .cmps,
@@ -288,6 +301,8 @@ fn imm(lower: Lower, ops: Mir.Inst.Ops, i: u32) Immediate {
         .rmi_rip,
         .mri_sib,
         .mri_rip,
+        .rrm_sib,
+        .rrm_rip,
         .rrmi_sib,
         .rrmi_rip,
         => Immediate.u(i),
@@ -310,6 +325,7 @@ fn mem(lower: Lower, ops: Mir.Inst.Ops, payload: u32) Memory {
         .mr_sib,
         .mrr_sib,
         .mri_sib,
+        .rrm_sib,
         .rrmi_sib,
         .lock_m_sib,
         .lock_mi_sib_u,
@@ -327,6 +343,7 @@ fn mem(lower: Lower, ops: Mir.Inst.Ops, payload: u32) Memory {
         .mr_rip,
         .mrr_rip,
         .mri_rip,
+        .rrm_rip,
         .rrmi_rip,
         .lock_m_rip,
         .lock_mi_rip_u,
@@ -449,6 +466,11 @@ fn mirGeneric(lower: *Lower, inst: Mir.Inst) Error!void {
             .{ .reg = inst.data.rix.r },
             .{ .imm = lower.imm(inst.ops, inst.data.rix.i) },
         },
+        .rrm_sib, .rrm_rip => &.{
+            .{ .reg = inst.data.rrx.r1 },
+            .{ .reg = inst.data.rrx.r2 },
+            .{ .mem = lower.mem(inst.ops, inst.data.rrx.payload) },
+        },
         .rrmi_sib, .rrmi_rip => &.{
             .{ .reg = inst.data.rrix.r1 },
             .{ .reg = inst.data.rrix.r2 },
src/arch/x86_64/Mir.zig
@@ -324,6 +324,31 @@ pub const Inst = struct {
         /// Convert single-precision floating-point values to 16-bit floating-point values
         vcvtps2ph,
 
+        /// Fused multiply-add of packed double-precision floating-point values
+        vfmadd132pd,
+        /// Fused multiply-add of packed double-precision floating-point values
+        vfmadd213pd,
+        /// Fused multiply-add of packed double-precision floating-point values
+        vfmadd231pd,
+        /// Fused multiply-add of packed single-precision floating-point values
+        vfmadd132ps,
+        /// Fused multiply-add of packed single-precision floating-point values
+        vfmadd213ps,
+        /// Fused multiply-add of packed single-precision floating-point values
+        vfmadd231ps,
+        /// Fused multiply-add of scalar double-precision floating-point values
+        vfmadd132sd,
+        /// Fused multiply-add of scalar double-precision floating-point values
+        vfmadd213sd,
+        /// Fused multiply-add of scalar double-precision floating-point values
+        vfmadd231sd,
+        /// Fused multiply-add of scalar single-precision floating-point values
+        vfmadd132ss,
+        /// Fused multiply-add of scalar single-precision floating-point values
+        vfmadd213ss,
+        /// Fused multiply-add of scalar single-precision floating-point values
+        vfmadd231ss,
+
         /// Compare string operands
         cmps,
         /// Load string
@@ -434,6 +459,12 @@ pub const Inst = struct {
         /// Register, memory (SIB), immediate (byte) operands.
         /// Uses `rix` payload with extra data of type `MemorySib`.
         rmi_sib,
+        /// Register, register, memory (RIP).
+        /// Uses `rrix` payload with extra data of type `MemoryRip`.
+        rrm_rip,
+        /// Register, register, memory (SIB).
+        /// Uses `rrix` payload with extra data of type `MemorySib`.
+        rrm_sib,
         /// Register, register, memory (RIP), immediate (byte) operands.
         /// Uses `rrix` payload with extra data of type `MemoryRip`.
         rrmi_rip,
test/behavior/muladd.zig
@@ -1,8 +1,10 @@
+const std = @import("std");
 const builtin = @import("builtin");
-const expect = @import("std").testing.expect;
+const expect = std.testing.expect;
 
 test "@mulAdd" {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and
+        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .fma)) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO