Commit 5c5da179fb

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-07 09:47:56
x86_64: implement `@sqrt` for vectors
1 parent 05580b9
src/arch/x86_64/CodeGen.zig
@@ -4520,25 +4520,69 @@ fn airRound(self: *Self, inst: Air.Inst.Index, mode: Immediate) !void {
 fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const ty = self.air.typeOf(un_op);
+    const abi_size = @intCast(u32, ty.abiSize(self.target.*));
 
     const src_mcv = try self.resolveInst(un_op);
     const dst_mcv = if (src_mcv.isRegister() and self.reuseOperand(inst, un_op, 0, src_mcv))
         src_mcv
     else
         try self.copyToRegisterWithInstTracking(inst, ty, src_mcv);
+    const dst_reg = registerAlias(dst_mcv.getReg().?, abi_size);
+    const dst_lock = self.register_manager.lockReg(dst_reg);
+    defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
 
-    try self.genBinOpMir(switch (ty.zigTypeTag()) {
-        .Float => switch (ty.floatBits(self.target.*)) {
-            32 => .sqrtss,
-            64 => .sqrtsd,
-            else => return self.fail("TODO implement airSqrt for {}", .{
-                ty.fmt(self.bin_file.options.module.?),
-            }),
+    const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) {
+        .Float => switch (ty.childType().floatBits(self.target.*)) {
+            32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+            64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+            16, 80, 128 => null,
+            else => unreachable,
         },
-        else => return self.fail("TODO implement airSqrt for {}", .{
-            ty.fmt(self.bin_file.options.module.?),
-        }),
-    }, ty, dst_mcv, src_mcv);
+        .Vector => switch (ty.childType().zigTypeTag()) {
+            .Float => switch (ty.childType().floatBits(self.target.*)) {
+                32 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+                    2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps,
+                    5...8 => if (self.hasFeature(.avx)) .vsqrtps else null,
+                    else => null,
+                },
+                64 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+                    2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd,
+                    3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null,
+                    else => null,
+                },
+                16, 80, 128 => null,
+                else => unreachable,
+            },
+            else => unreachable,
+        },
+        else => unreachable,
+    })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{
+        ty.fmt(self.bin_file.options.module.?),
+    });
+    switch (tag) {
+        .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister(
+            tag,
+            dst_reg,
+            dst_reg,
+            registerAlias(src_mcv.getReg().?, abi_size),
+        ) else try self.asmRegisterRegisterMemory(
+            tag,
+            dst_reg,
+            dst_reg,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+        ),
+        else => if (src_mcv.isRegister()) try self.asmRegisterRegister(
+            tag,
+            dst_reg,
+            registerAlias(src_mcv.getReg().?, abi_size),
+        ) else try self.asmRegisterMemory(
+            tag,
+            dst_reg,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+        ),
+    }
     return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
 }
 
@@ -9544,85 +9588,92 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
         lock.* = self.register_manager.lockRegAssumeUnused(reg);
     }
 
-    const tag: ?Mir.Inst.Tag =
+    const tag = if (@as(
+        ?Mir.Inst.Tag,
         if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or mem.eql(u2, &order, &.{ 3, 1, 2 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd132ss,
-                64 => .vfmadd132sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd132ss,
-                        2...8 => .vfmadd132ps,
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd132sd,
-                        2...4 => .vfmadd132pd,
-                        else => null,
-                    },
-                    else => null,
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd132ss,
+                    64 => .vfmadd132sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd213ss,
-                64 => .vfmadd213sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd213ss,
-                        2...8 => .vfmadd213ps,
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd213sd,
-                        2...4 => .vfmadd213pd,
-                        else => null,
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd132ss,
+                            2...8 => .vfmadd132ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd132sd,
+                            2...4 => .vfmadd132pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
-                    else => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd231ss,
-                64 => .vfmadd231sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd231ss,
-                        2...8 => .vfmadd231ps,
-                        else => null,
+                else => unreachable,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd213ss,
+                    64 => .vfmadd213sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
+                },
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd213ss,
+                            2...8 => .vfmadd213ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd213sd,
+                            2...4 => .vfmadd213pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd231sd,
-                        2...4 => .vfmadd231pd,
-                        else => null,
+                    else => unreachable,
+                },
+                else => unreachable,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd231ss,
+                    64 => .vfmadd231sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
+                },
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd231ss,
+                            2...8 => .vfmadd231ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd231sd,
+                            2...4 => .vfmadd231pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
-                    else => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => null,
-        }
-    else
-        unreachable;
-    if (tag == null) return self.fail("TODO implement airMulAdd for {}", .{
+                else => unreachable,
+            }
+        else
+            unreachable,
+    )) |tag| tag else return self.fail("TODO implement airMulAdd for {}", .{
         ty.fmt(self.bin_file.options.module.?),
     });
 
@@ -9634,14 +9685,14 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
     const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
     if (mops[2].isRegister())
         try self.asmRegisterRegisterRegister(
-            tag.?,
+            tag,
             mop1_reg,
             mop2_reg,
             registerAlias(mops[2].getReg().?, abi_size),
         )
     else
         try self.asmRegisterRegisterMemory(
-            tag.?,
+            tag,
             mop1_reg,
             mop2_reg,
             mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
src/arch/x86_64/Encoding.zig
@@ -316,6 +316,7 @@ pub const Mnemonic = enum {
     vpsrld, vpsrlq, vpsrlw,
     vpunpckhbw, vpunpckhdq, vpunpckhqdq, vpunpckhwd,
     vpunpcklbw, vpunpckldq, vpunpcklqdq, vpunpcklwd,
+    vsqrtpd, vsqrtps, vsqrtsd, vsqrtss,
     // F16C
     vcvtph2ps, vcvtps2ph,
     // FMA
src/arch/x86_64/encodings.zig
@@ -869,8 +869,9 @@ pub const table = [_]Entry{
 
     .{ .subss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x5c }, 0, .none, .sse },
 
-    .{ .sqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{       0x0f, 0x51 }, 0, .none, .sse },
-    .{ .sqrtss, .rm, &.{ .xmm, .xmm_m32  }, &.{ 0xf3, 0x0f, 0x51 }, 0, .none, .sse },
+    .{ .sqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .none, .sse },
+
+    .{ .sqrtss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x51 }, 0, .none, .sse },
 
     .{ .ucomiss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0x0f, 0x2e }, 0, .none, .sse },
 
@@ -943,7 +944,8 @@ pub const table = [_]Entry{
     .{ .punpcklqdq, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6c }, 0, .none, .sse2 },
 
     .{ .sqrtpd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x51 }, 0, .none, .sse2 },
-    .{ .sqrtsd, .rm, &.{ .xmm, .xmm_m64  }, &.{ 0xf2, 0x0f, 0x51 }, 0, .none, .sse2 },
+
+    .{ .sqrtsd, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x51 }, 0, .none, .sse2 },
 
     .{ .subsd, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x5c }, 0, .none, .sse2 },
 
@@ -1039,6 +1041,16 @@ pub const table = [_]Entry{
     .{ .vpunpckldq,  .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x62 }, 0, .vex_128_wig, .avx },
     .{ .vpunpcklqdq, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6c }, 0, .vex_128_wig, .avx },
 
+    .{ .vsqrtpd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x51 }, 0, .vex_128_wig, .avx },
+    .{ .vsqrtpd, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x51 }, 0, .vex_256_wig, .avx },
+
+    .{ .vsqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .vex_128_wig, .avx },
+    .{ .vsqrtps, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x0f, 0x51 }, 0, .vex_256_wig, .avx },
+
+    .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f }, 0, .vex_lig_wig, .avx },
+
+    .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f }, 0, .vex_lig_wig, .avx },
+
     // F16C
     .{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64  }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128_w0, .f16c },
     .{ .vcvtph2ps, .rm, &.{ .ymm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_256_w0, .f16c },
src/arch/x86_64/Lower.zig
@@ -212,6 +212,10 @@ pub fn lowerMir(lower: *Lower, index: Mir.Inst.Index) Error!struct {
         .vpunpckldq,
         .vpunpcklqdq,
         .vpunpcklwd,
+        .vsqrtpd,
+        .vsqrtps,
+        .vsqrtsd,
+        .vsqrtss,
 
         .vcvtph2ps,
         .vcvtps2ph,
src/arch/x86_64/Mir.zig
@@ -338,6 +338,14 @@ pub const Inst = struct {
         vpunpcklqdq,
         /// Unpack low data
         vpunpcklwd,
+        /// Square root of packed double-precision floating-point value
+        vsqrtpd,
+        /// Square root of packed single-precision floating-point value
+        vsqrtps,
+        /// Square root of scalar double-precision floating-point value
+        vsqrtsd,
+        /// Square root of scalar single-precision floating-point value
+        vsqrtss,
 
         /// Convert 16-bit floating-point values to single-precision floating-point values
         vcvtph2ps,