Commit 35da95fe87

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-17 06:23:11
x86_64: implement integer vector `@truncate`
1 parent 28c445a
Changed files (5)
src/arch/x86_64/CodeGen.zig
@@ -2709,28 +2709,112 @@ fn airTrunc(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
 
     const dst_ty = self.air.typeOfIndex(inst);
-    const dst_abi_size = dst_ty.abiSize(self.target.*);
-    if (dst_abi_size > 8) {
-        return self.fail("TODO implement trunc for abi sizes larger than 8", .{});
-    }
+    const dst_abi_size = @intCast(u32, dst_ty.abiSize(self.target.*));
+    const src_ty = self.air.typeOf(ty_op.operand);
+    const src_abi_size = @intCast(u32, src_ty.abiSize(self.target.*));
 
-    const src_mcv = try self.resolveInst(ty_op.operand);
-    const src_lock = switch (src_mcv) {
-        .register => |reg| self.register_manager.lockRegAssumeUnused(reg),
-        else => null,
-    };
-    defer if (src_lock) |lock| self.register_manager.unlockReg(lock);
+    const result = result: {
+        const src_mcv = try self.resolveInst(ty_op.operand);
+        const src_lock =
+            if (src_mcv.getReg()) |reg| self.register_manager.lockRegAssumeUnused(reg) else null;
+        defer if (src_lock) |lock| self.register_manager.unlockReg(lock);
 
-    const dst_mcv = if (src_mcv.isRegister() and self.reuseOperand(inst, ty_op.operand, 0, src_mcv))
-        src_mcv
-    else
-        try self.copyToRegisterWithInstTracking(inst, dst_ty, src_mcv);
+        const dst_mcv = if (src_mcv.isRegister() and self.reuseOperand(inst, ty_op.operand, 0, src_mcv))
+            src_mcv
+        else
+            try self.copyToRegisterWithInstTracking(inst, dst_ty, src_mcv);
+
+        if (dst_ty.zigTypeTag() == .Vector) {
+            assert(src_ty.zigTypeTag() == .Vector and dst_ty.vectorLen() == src_ty.vectorLen());
+            const dst_info = dst_ty.childType().intInfo(self.target.*);
+            const src_info = src_ty.childType().intInfo(self.target.*);
+            const mir_tag = if (@as(?Mir.Inst.FixedTag, switch (dst_info.bits) {
+                8 => switch (src_info.bits) {
+                    16 => switch (dst_ty.vectorLen()) {
+                        1...8 => if (self.hasFeature(.avx)) .{ .vp_b, .ackusw } else .{ .p_b, .ackusw },
+                        9...16 => if (self.hasFeature(.avx2)) .{ .vp_b, .ackusw } else null,
+                        else => null,
+                    },
+                    else => null,
+                },
+                16 => switch (src_info.bits) {
+                    32 => switch (dst_ty.vectorLen()) {
+                        1...4 => if (self.hasFeature(.avx))
+                            .{ .vp_w, .ackusd }
+                        else if (self.hasFeature(.sse4_1))
+                            .{ .p_w, .ackusd }
+                        else
+                            null,
+                        5...8 => if (self.hasFeature(.avx2)) .{ .vp_w, .ackusd } else null,
+                        else => null,
+                    },
+                    else => null,
+                },
+                else => null,
+            })) |tag| tag else return self.fail("TODO implement airTrunc for {}", .{
+                dst_ty.fmt(self.bin_file.options.module.?),
+            });
 
-    // when truncating a `u16` to `u5`, for example, those top 3 bits in the result
-    // have to be removed. this only happens if the dst if not a power-of-two size.
-    if (self.regExtraBits(dst_ty) > 0) try self.truncateRegister(dst_ty, dst_mcv.register.to64());
+            var mask_pl = Value.Payload.U64{
+                .base = .{ .tag = .int_u64 },
+                .data = @as(u64, math.maxInt(u64)) >> @intCast(u6, 64 - dst_info.bits),
+            };
+            const mask_val = Value.initPayload(&mask_pl.base);
 
-    return self.finishAir(inst, dst_mcv, .{ ty_op.operand, .none, .none });
+            var splat_pl = Value.Payload.SubValue{
+                .base = .{ .tag = .repeated },
+                .data = mask_val,
+            };
+            const splat_val = Value.initPayload(&splat_pl.base);
+
+            var full_pl = Type.Payload.Array{
+                .base = .{ .tag = .vector },
+                .data = .{
+                    .len = @divExact(@as(u64, if (src_abi_size > 16) 256 else 128), src_info.bits),
+                    .elem_type = src_ty.childType(),
+                },
+            };
+            const full_ty = Type.initPayload(&full_pl.base);
+            const full_abi_size = @intCast(u32, full_ty.abiSize(self.target.*));
+
+            const splat_mcv = try self.genTypedValue(.{ .ty = full_ty, .val = splat_val });
+            const splat_addr_mcv: MCValue = switch (splat_mcv) {
+                .memory, .indirect, .load_frame => splat_mcv.address(),
+                else => .{ .register = try self.copyToTmpRegister(Type.usize, splat_mcv.address()) },
+            };
+
+            const dst_reg = registerAlias(dst_mcv.getReg().?, src_abi_size);
+            if (self.hasFeature(.avx)) {
+                try self.asmRegisterRegisterMemory(
+                    .{ .vp_, .@"and" },
+                    dst_reg,
+                    dst_reg,
+                    splat_addr_mcv.deref().mem(Memory.PtrSize.fromSize(full_abi_size)),
+                );
+                try self.asmRegisterRegisterRegister(mir_tag, dst_reg, dst_reg, dst_reg);
+            } else {
+                try self.asmRegisterMemory(
+                    .{ .p_, .@"and" },
+                    dst_reg,
+                    splat_addr_mcv.deref().mem(Memory.PtrSize.fromSize(full_abi_size)),
+                );
+                try self.asmRegisterRegister(mir_tag, dst_reg, dst_reg);
+            }
+            break :result dst_mcv;
+        }
+
+        if (dst_abi_size > 8) {
+            return self.fail("TODO implement trunc for abi sizes larger than 8", .{});
+        }
+
+        // when truncating a `u16` to `u5`, for example, those top 3 bits in the result
+        // have to be removed. this only happens if the dst if not a power-of-two size.
+        if (self.regExtraBits(dst_ty) > 0)
+            try self.truncateRegister(dst_ty, dst_mcv.register.to64());
+
+        break :result dst_mcv;
+    };
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
 fn airBoolToInt(self: *Self, inst: Air.Inst.Index) !void {
@@ -11081,8 +11165,8 @@ fn airSelect(self: *Self, inst: Air.Inst.Index) !void {
 }
 
 fn airShuffle(self: *Self, inst: Air.Inst.Index) !void {
-    const ty_op = self.air.instructions.items(.data)[inst].ty_op;
-    _ = ty_op;
+    const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
+    _ = ty_pl;
     return self.fail("TODO implement airShuffle for x86_64", .{});
     //return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
src/arch/x86_64/Encoding.zig
@@ -263,6 +263,7 @@ pub const Mnemonic = enum {
     fisttp, fld,
     // MMX
     movd, movq,
+    packssdw, packsswb, packuswb,
     paddb, paddd, paddq, paddsb, paddsw, paddusb, paddusw, paddw,
     pand, pandn, por, pxor,
     pmulhw, pmullw,
@@ -319,6 +320,7 @@ pub const Mnemonic = enum {
     blendpd, blendps, blendvpd, blendvps,
     extractps,
     insertps,
+    packusdw,
     pextrb, pextrd, pextrq,
     pinsrb, pinsrd, pinsrq,
     pmaxsb, pmaxsd, pmaxud, pmaxuw, pminsb, pminsd, pminud, pminuw,
@@ -351,6 +353,7 @@ pub const Mnemonic = enum {
     vmovupd, vmovups,
     vmulpd, vmulps, vmulsd, vmulss,
     vorpd, vorps,
+    vpackssdw, vpacksswb, vpackusdw, vpackuswb,
     vpaddb, vpaddd, vpaddq, vpaddsb, vpaddsw, vpaddusb, vpaddusw, vpaddw,
     vpand, vpandn,
     vpextrb, vpextrd, vpextrq, vpextrw,
src/arch/x86_64/encodings.zig
@@ -996,6 +996,11 @@ pub const table = [_]Entry{
 
     .{ .orpd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x56 }, 0, .none, .sse2 },
 
+    .{ .packsswb, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x63 }, 0, .none, .sse2 },
+    .{ .packssdw, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6b }, 0, .none, .sse2 },
+
+    .{ .packuswb, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x67 }, 0, .none, .sse2 },
+
     .{ .paddb, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfc }, 0, .none, .sse2 },
     .{ .paddw, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfd }, 0, .none, .sse2 },
     .{ .paddd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfe }, 0, .none, .sse2 },
@@ -1101,6 +1106,8 @@ pub const table = [_]Entry{
 
     .{ .insertps, .rmi, &.{ .xmm, .xmm_m32, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x21 }, 0, .none, .sse4_1 },
 
+    .{ .packusdw, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x2b }, 0, .none, .sse4_1 },
+
     .{ .pextrb, .mri, &.{ .r32_m8, .xmm, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x14 }, 0, .none, .sse4_1 },
     .{ .pextrd, .mri, &.{ .rm32,   .xmm, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x16 }, 0, .none, .sse4_1 },
     .{ .pextrq, .mri, &.{ .rm64,   .xmm, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x16 }, 0, .long, .sse4_1 },
@@ -1346,6 +1353,13 @@ pub const table = [_]Entry{
     .{ .vorps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x0f, 0x56 }, 0, .vex_128_wig, .avx },
     .{ .vorps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x0f, 0x56 }, 0, .vex_256_wig, .avx },
 
+    .{ .vpacksswb, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x63 }, 0, .vex_128_wig, .avx },
+    .{ .vpackssdw, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6b }, 0, .vex_128_wig, .avx },
+
+    .{ .vpackusdw, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x2b }, 0, .vex_128_wig, .avx },
+
+    .{ .vpackuswb, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x67 }, 0, .vex_128_wig, .avx },
+
     .{ .vpaddb, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfc }, 0, .vex_128_wig, .avx },
     .{ .vpaddw, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfd }, 0, .vex_128_wig, .avx },
     .{ .vpaddd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0xfe }, 0, .vex_128_wig, .avx },
@@ -1508,6 +1522,13 @@ pub const table = [_]Entry{
     .{ .vbroadcastss, .rm, &.{ .ymm, .xmm }, &.{ 0x66, 0x0f, 0x38, 0x18 }, 0, .vex_256_w0, .avx2 },
     .{ .vbroadcastsd, .rm, &.{ .ymm, .xmm }, &.{ 0x66, 0x0f, 0x38, 0x19 }, 0, .vex_256_w0, .avx2 },
 
+    .{ .vpacksswb, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x63 }, 0, .vex_256_wig, .avx2 },
+    .{ .vpackssdw, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x6b }, 0, .vex_256_wig, .avx2 },
+
+    .{ .vpackusdw, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0x2b }, 0, .vex_256_wig, .avx2 },
+
+    .{ .vpackuswb, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x67 }, 0, .vex_256_wig, .avx2 },
+
     .{ .vpaddb, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0xfc }, 0, .vex_256_wig, .avx2 },
     .{ .vpaddw, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0xfd }, 0, .vex_256_wig, .avx2 },
     .{ .vpaddd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0xfe }, 0, .vex_256_wig, .avx2 },
src/arch/x86_64/Mir.zig
@@ -446,6 +446,12 @@ pub const Inst = struct {
         /// Bitwise logical xor of packed double-precision floating-point values
         xor,
 
+        /// Pack with signed saturation
+        ackssw,
+        /// Pack with signed saturation
+        ackssd,
+        /// Pack with unsigned saturation
+        ackusw,
         /// Add packed signed integers with signed saturation
         adds,
         /// Add packed unsigned integers with unsigned saturation
@@ -596,6 +602,8 @@ pub const Inst = struct {
         /// Replicate single floating-point values
         movsldup,
 
+        /// Pack with unsigned saturation
+        ackusd,
         /// Blend packed single-precision floating-point values
         /// Blend scalar single-precision floating-point values
         /// Blend packed double-precision floating-point values
test/behavior/truncate.zig
@@ -61,7 +61,6 @@ test "truncate on comptime integer" {
 
 test "truncate on vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO