Commit ea957c4cff

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-07 11:01:37
x86_64: implement `@sqrt` for `f16` scalars and vectors
1 parent 5c5da17
Changed files (3)
src
test
behavior
src/arch/x86_64/CodeGen.zig
@@ -4531,59 +4531,117 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
     const dst_lock = self.register_manager.lockReg(dst_reg);
     defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
 
-    const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) {
-        .Float => switch (ty.childType().floatBits(self.target.*)) {
-            32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
-            64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
-            16, 80, 128 => null,
-            else => unreachable,
-        },
-        .Vector => switch (ty.childType().zigTypeTag()) {
-            .Float => switch (ty.childType().floatBits(self.target.*)) {
-                32 => switch (ty.vectorLen()) {
-                    1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
-                    2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps,
-                    5...8 => if (self.hasFeature(.avx)) .vsqrtps else null,
-                    else => null,
-                },
-                64 => switch (ty.vectorLen()) {
-                    1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
-                    2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd,
-                    3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null,
-                    else => null,
+    const result: MCValue = result: {
+        const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) {
+            .Float => switch (ty.floatBits(self.target.*)) {
+                16 => if (self.hasFeature(.f16c)) {
+                    const mat_src_reg = if (src_mcv.isRegister())
+                        src_mcv.getReg().?
+                    else
+                        try self.copyToTmpRegister(ty, src_mcv);
+                    try self.asmRegisterRegister(.vcvtph2ps, dst_reg, mat_src_reg.to128());
+                    try self.asmRegisterRegisterRegister(.vsqrtss, dst_reg, dst_reg, dst_reg);
+                    try self.asmRegisterRegisterImmediate(
+                        .vcvtps2ph,
+                        dst_reg,
+                        dst_reg,
+                        Immediate.u(0b1_00),
+                    );
+                    break :result dst_mcv;
+                } else null,
+                32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+                64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+                80, 128 => null,
+                else => unreachable,
+            },
+            .Vector => switch (ty.childType().zigTypeTag()) {
+                .Float => switch (ty.childType().floatBits(self.target.*)) {
+                    16 => if (self.hasFeature(.f16c)) switch (ty.vectorLen()) {
+                        1 => {
+                            const mat_src_reg = if (src_mcv.isRegister())
+                                src_mcv.getReg().?
+                            else
+                                try self.copyToTmpRegister(ty, src_mcv);
+                            try self.asmRegisterRegister(.vcvtph2ps, dst_reg, mat_src_reg.to128());
+                            try self.asmRegisterRegisterRegister(.vsqrtss, dst_reg, dst_reg, dst_reg);
+                            try self.asmRegisterRegisterImmediate(
+                                .vcvtps2ph,
+                                dst_reg,
+                                dst_reg,
+                                Immediate.u(0b1_00),
+                            );
+                            break :result dst_mcv;
+                        },
+                        2...8 => {
+                            const wide_reg = registerAlias(dst_reg, abi_size * 2);
+                            if (src_mcv.isRegister()) try self.asmRegisterRegister(
+                                .vcvtph2ps,
+                                wide_reg,
+                                src_mcv.getReg().?.to128(),
+                            ) else try self.asmRegisterMemory(
+                                .vcvtph2ps,
+                                wide_reg,
+                                src_mcv.mem(Memory.PtrSize.fromSize(
+                                    @intCast(u32, @divExact(wide_reg.bitSize(), 16)),
+                                )),
+                            );
+                            try self.asmRegisterRegister(.vsqrtps, wide_reg, wide_reg);
+                            try self.asmRegisterRegisterImmediate(
+                                .vcvtps2ph,
+                                dst_reg,
+                                wide_reg,
+                                Immediate.u(0b1_00),
+                            );
+                            break :result dst_mcv;
+                        },
+                        else => null,
+                    } else null,
+                    32 => switch (ty.vectorLen()) {
+                        1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+                        2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps,
+                        5...8 => if (self.hasFeature(.avx)) .vsqrtps else null,
+                        else => null,
+                    },
+                    64 => switch (ty.vectorLen()) {
+                        1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+                        2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd,
+                        3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null,
+                        else => null,
+                    },
+                    80, 128 => null,
+                    else => unreachable,
                 },
-                16, 80, 128 => null,
                 else => unreachable,
             },
             else => unreachable,
-        },
-        else => unreachable,
-    })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{
-        ty.fmt(self.bin_file.options.module.?),
-    });
-    switch (tag) {
-        .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister(
-            tag,
-            dst_reg,
-            dst_reg,
-            registerAlias(src_mcv.getReg().?, abi_size),
-        ) else try self.asmRegisterRegisterMemory(
-            tag,
-            dst_reg,
-            dst_reg,
-            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
-        ),
-        else => if (src_mcv.isRegister()) try self.asmRegisterRegister(
-            tag,
-            dst_reg,
-            registerAlias(src_mcv.getReg().?, abi_size),
-        ) else try self.asmRegisterMemory(
-            tag,
-            dst_reg,
-            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
-        ),
-    }
-    return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
+        })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{
+            ty.fmt(self.bin_file.options.module.?),
+        });
+        switch (tag) {
+            .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister(
+                tag,
+                dst_reg,
+                dst_reg,
+                registerAlias(src_mcv.getReg().?, abi_size),
+            ) else try self.asmRegisterRegisterMemory(
+                tag,
+                dst_reg,
+                dst_reg,
+                src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+            ),
+            else => if (src_mcv.isRegister()) try self.asmRegisterRegister(
+                tag,
+                dst_reg,
+                registerAlias(src_mcv.getReg().?, abi_size),
+            ) else try self.asmRegisterMemory(
+                tag,
+                dst_reg,
+                src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+            ),
+        }
+        break :result dst_mcv;
+    };
+    return self.finishAir(inst, result, .{ un_op, .none, .none });
 }
 
 fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void {
src/arch/x86_64/encodings.zig
@@ -1047,9 +1047,9 @@ pub const table = [_]Entry{
     .{ .vsqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .vex_128_wig, .avx },
     .{ .vsqrtps, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x0f, 0x51 }, 0, .vex_256_wig, .avx },
 
-    .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f }, 0, .vex_lig_wig, .avx },
+    .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx },
 
-    .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f }, 0, .vex_lig_wig, .avx },
+    .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx },
 
     // F16C
     .{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64  }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128_w0, .f16c },
test/behavior/floatop.zig
@@ -135,7 +135,6 @@ fn testSqrt() !void {
 
 test "@sqrt with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO