Commit f8708e2c4d

Jacob Young <jacobly0@users.noreply.github.com>
2023-05-07 16:04:56
x86_64: implement `@floor`, `@ceil`, and `@trunc` for float vectors
1 parent 057139f
Changed files (2)
src
arch
test
behavior
src/arch/x86_64/CodeGen.zig
@@ -1587,9 +1587,9 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .round,
             => try self.airUnaryMath(inst),
 
-            .floor => try self.airRound(inst, Immediate.u(0b1_0_01)),
-            .ceil => try self.airRound(inst, Immediate.u(0b1_0_10)),
-            .trunc_float => try self.airRound(inst, Immediate.u(0b1_0_11)),
+            .floor => try self.airRound(inst, 0b1_0_01),
+            .ceil => try self.airRound(inst, 0b1_0_10),
+            .trunc_float => try self.airRound(inst, 0b1_0_11),
             .sqrt => try self.airSqrt(inst),
             .neg, .fabs => try self.airFloatSign(inst),
 
@@ -4509,49 +4509,91 @@ fn airFloatSign(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
 }
 
-fn airRound(self: *Self, inst: Air.Inst.Index, mode: Immediate) !void {
+fn airRound(self: *Self, inst: Air.Inst.Index, mode: u4) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const ty = self.air.typeOf(un_op);
 
-    if (!self.hasFeature(.sse4_1))
-        return self.fail("TODO implement airRound without sse4_1 feature", .{});
-
     const src_mcv = try self.resolveInst(un_op);
     const dst_mcv = if (src_mcv.isRegister() and self.reuseOperand(inst, un_op, 0, src_mcv))
         src_mcv
     else
         try self.copyToRegisterWithInstTracking(inst, ty, src_mcv);
+    const dst_reg = dst_mcv.getReg().?;
+    const dst_lock = self.register_manager.lockReg(dst_reg);
+    defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
+    try self.genRound(ty, dst_reg, src_mcv, mode);
+    return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
+}
 
-    const mir_tag: Mir.Inst.Tag = switch (ty.zigTypeTag()) {
+fn genRound(self: *Self, ty: Type, dst_reg: Register, src_mcv: MCValue, mode: u4) !void {
+    if (!self.hasFeature(.sse4_1))
+        return self.fail("TODO implement genRound without sse4_1 feature", .{});
+
+    const mir_tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) {
         .Float => switch (ty.floatBits(self.target.*)) {
-            32 => .roundss,
-            64 => .roundsd,
-            else => return self.fail("TODO implement airRound for {}", .{
-                ty.fmt(self.bin_file.options.module.?),
-            }),
+            32 => if (self.hasFeature(.avx)) .vroundss else .roundss,
+            64 => if (self.hasFeature(.avx)) .vroundsd else .roundsd,
+            16, 80, 128 => null,
+            else => unreachable,
         },
-        else => return self.fail("TODO implement airRound for {}", .{
-            ty.fmt(self.bin_file.options.module.?),
-        }),
-    };
-    assert(dst_mcv.isRegister());
+        .Vector => switch (ty.childType().zigTypeTag()) {
+            .Float => switch (ty.childType().floatBits(self.target.*)) {
+                32 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vroundss else .roundss,
+                    2...4 => if (self.hasFeature(.avx)) .vroundps else .roundps,
+                    5...8 => if (self.hasFeature(.avx)) .vroundps else null,
+                    else => null,
+                },
+                64 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vroundsd else .roundsd,
+                    2 => if (self.hasFeature(.avx)) .vroundpd else .roundpd,
+                    3...4 => if (self.hasFeature(.avx)) .vroundpd else null,
+                    else => null,
+                },
+                16, 80, 128 => null,
+                else => unreachable,
+            },
+            else => null,
+        },
+        else => unreachable,
+    })) |tag| tag else return self.fail("TODO implement genRound for {}", .{
+        ty.fmt(self.bin_file.options.module.?),
+    });
+
     const abi_size = @intCast(u32, ty.abiSize(self.target.*));
-    const dst_reg = registerAlias(dst_mcv.getReg().?, abi_size);
-    if (src_mcv.isRegister())
-        try self.asmRegisterRegisterImmediate(
+    const dst_alias = registerAlias(dst_reg, abi_size);
+    switch (mir_tag) {
+        .vroundss, .vroundsd => if (src_mcv.isMemory()) try self.asmRegisterRegisterMemoryImmediate(
             mir_tag,
-            dst_reg,
-            registerAlias(src_mcv.getReg().?, abi_size),
-            mode,
-        )
-    else
-        try self.asmRegisterMemoryImmediate(
+            dst_alias,
+            dst_alias,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+            Immediate.u(mode),
+        ) else try self.asmRegisterRegisterRegisterImmediate(
             mir_tag,
-            dst_reg,
-            src_mcv.mem(Memory.PtrSize.fromSize(@intCast(u32, ty.abiSize(self.target.*)))),
-            mode,
-        );
-    return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
+            dst_alias,
+            dst_alias,
+            registerAlias(if (src_mcv.isRegister())
+                src_mcv.getReg().?
+            else
+                try self.copyToTmpRegister(ty, src_mcv), abi_size),
+            Immediate.u(mode),
+        ),
+        else => if (src_mcv.isMemory()) try self.asmRegisterMemoryImmediate(
+            mir_tag,
+            dst_alias,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+            Immediate.u(mode),
+        ) else try self.asmRegisterRegisterImmediate(
+            mir_tag,
+            dst_alias,
+            registerAlias(if (src_mcv.isRegister())
+                src_mcv.getReg().?
+            else
+                try self.copyToTmpRegister(ty, src_mcv), abi_size),
+            Immediate.u(mode),
+        ),
+    }
 }
 
 fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
@@ -6188,18 +6230,18 @@ fn genBinOp(
     })) |tag| tag else return self.fail("TODO implement genBinOp for {s} {}", .{
         @tagName(air_tag), lhs_ty.fmt(self.bin_file.options.module.?),
     });
-    const dst_alias = registerAlias(dst_mcv.getReg().?, abi_size);
+    const dst_reg = registerAlias(dst_mcv.getReg().?, abi_size);
     if (self.hasFeature(.avx)) {
         const src1_alias =
-            if (copied_to_dst) dst_alias else registerAlias(lhs_mcv.getReg().?, abi_size);
+            if (copied_to_dst) dst_reg else registerAlias(lhs_mcv.getReg().?, abi_size);
         if (src_mcv.isMemory()) try self.asmRegisterRegisterMemory(
             mir_tag,
-            dst_alias,
+            dst_reg,
             src1_alias,
             src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
         ) else try self.asmRegisterRegisterRegister(
             mir_tag,
-            dst_alias,
+            dst_reg,
             src1_alias,
             registerAlias(if (src_mcv.isRegister())
                 src_mcv.getReg().?
@@ -6210,11 +6252,11 @@ fn genBinOp(
         assert(copied_to_dst);
         if (src_mcv.isMemory()) try self.asmRegisterMemory(
             mir_tag,
-            dst_alias,
+            dst_reg,
             src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
         ) else try self.asmRegisterRegister(
             mir_tag,
-            dst_alias,
+            dst_reg,
             registerAlias(if (src_mcv.isRegister())
                 src_mcv.getReg().?
             else
@@ -6223,60 +6265,16 @@ fn genBinOp(
     }
     switch (air_tag) {
         .add, .sub, .mul, .div_float, .div_exact => {},
-        .div_trunc, .div_floor => if (self.hasFeature(.sse4_1)) {
-            const round_tag = if (@as(?Mir.Inst.Tag, switch (lhs_ty.zigTypeTag()) {
-                .Float => switch (lhs_ty.floatBits(self.target.*)) {
-                    32 => if (self.hasFeature(.avx)) .vroundss else .roundss,
-                    64 => if (self.hasFeature(.avx)) .vroundsd else .roundsd,
-                    16, 80, 128 => null,
-                    else => unreachable,
-                },
-                .Vector => switch (lhs_ty.childType().zigTypeTag()) {
-                    .Float => switch (lhs_ty.childType().floatBits(self.target.*)) {
-                        32 => switch (lhs_ty.vectorLen()) {
-                            1 => if (self.hasFeature(.avx)) .vroundss else .roundss,
-                            2...4 => if (self.hasFeature(.avx)) .vroundps else .roundps,
-                            5...8 => if (self.hasFeature(.avx)) .vroundps else null,
-                            else => null,
-                        },
-                        64 => switch (lhs_ty.vectorLen()) {
-                            1 => if (self.hasFeature(.avx)) .vroundsd else .roundsd,
-                            2 => if (self.hasFeature(.avx)) .vroundpd else .roundpd,
-                            3...4 => if (self.hasFeature(.avx)) .vroundpd else null,
-                            else => null,
-                        },
-                        16, 80, 128 => null,
-                        else => unreachable,
-                    },
-                    else => null,
-                },
-                else => unreachable,
-            })) |tag| tag else return self.fail("TODO implement genBinOp for {s} {}", .{
-                @tagName(air_tag), lhs_ty.fmt(self.bin_file.options.module.?),
-            });
-            const round_mode = Immediate.u(switch (air_tag) {
+        .div_trunc, .div_floor => try self.genRound(
+            lhs_ty,
+            dst_reg,
+            .{ .register = dst_reg },
+            switch (air_tag) {
                 .div_trunc => 0b1_0_11,
                 .div_floor => 0b1_0_01,
                 else => unreachable,
-            });
-            switch (round_tag) {
-                .vroundss, .vroundsd => try self.asmRegisterRegisterRegisterImmediate(
-                    round_tag,
-                    dst_alias,
-                    dst_alias,
-                    dst_alias,
-                    round_mode,
-                ),
-                else => try self.asmRegisterRegisterImmediate(
-                    round_tag,
-                    dst_alias,
-                    dst_alias,
-                    round_mode,
-                ),
-            }
-        } else return self.fail("TODO implement genBinOp for {s} {} without sse4_1", .{
-            @tagName(air_tag), lhs_ty.fmt(self.bin_file.options.module.?),
-        }),
+            },
+        ),
         .max, .min => {}, // TODO: unordered select
         else => unreachable,
     }
test/behavior/floatop.zig
@@ -617,7 +617,8 @@ fn testFloor() !void {
 
 test "@floor with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and
+        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
@@ -707,7 +708,8 @@ fn testCeil() !void {
 
 test "@ceil with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and
+        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
@@ -797,7 +799,8 @@ fn testTrunc() !void {
 
 test "@trunc with vectors" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and
+        !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO