Commit fbe8c8938b

Jacob Young <jacobly0@users.noreply.github.com>
2023-10-23 03:58:26
x86_64: implement `@mod` for floating-point types
1 parent fe93332
Changed files (3)
lib
src
arch
test
behavior
lib/std/math.zig
@@ -903,12 +903,12 @@ pub fn mod(comptime T: type, numerator: T, denominator: T) !T {
 }
 
 test "mod" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+
     try testMod();
     try comptime testMod();
 }
 fn testMod() !void {
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
-
     try testing.expect((mod(i32, -5, 3) catch unreachable) == 1);
     try testing.expect((mod(i32, 5, 3) catch unreachable) == 2);
     try testing.expectError(error.NegativeDenominator, mod(i32, 10, -1));
src/arch/x86_64/CodeGen.zig
@@ -2745,11 +2745,11 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) !void {
             },
             else => unreachable,
         }) {
-            var callee: ["__trunc?f?f2".len]u8 = undefined;
+            var callee_buf: ["__trunc?f?f2".len]u8 = undefined;
             break :result try self.genCall(.{ .lib = .{
                 .return_type = self.floatCompilerRtAbiType(dst_ty, src_ty).toIntern(),
                 .param_types = &.{self.floatCompilerRtAbiType(src_ty, dst_ty).toIntern()},
-                .callee = std.fmt.bufPrint(&callee, "__trunc{c}f{c}f2", .{
+                .callee = std.fmt.bufPrint(&callee_buf, "__trunc{c}f{c}f2", .{
                     floatCompilerRtAbiName(src_bits),
                     floatCompilerRtAbiName(dst_bits),
                 }) catch unreachable,
@@ -2777,7 +2777,7 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) !void {
                         .{ .v_, .cvtps2ph },
                         dst_reg,
                         mat_src_reg.to128(),
-                        Immediate.u(0b1_00),
+                        Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                     );
                 },
                 else => unreachable,
@@ -2844,11 +2844,11 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) !void {
             },
             else => unreachable,
         }) {
-            var callee: ["__extend?f?f2".len]u8 = undefined;
+            var callee_buf: ["__extend?f?f2".len]u8 = undefined;
             break :result try self.genCall(.{ .lib = .{
                 .return_type = self.floatCompilerRtAbiType(dst_ty, src_ty).toIntern(),
                 .param_types = &.{self.floatCompilerRtAbiType(src_ty, dst_ty).toIntern()},
-                .callee = std.fmt.bufPrint(&callee, "__extend{c}f{c}f2", .{
+                .callee = std.fmt.bufPrint(&callee_buf, "__extend{c}f{c}f2", .{
                     floatCompilerRtAbiName(src_bits),
                     floatCompilerRtAbiName(dst_bits),
                 }) catch unreachable,
@@ -5347,7 +5347,7 @@ const RoundMode = packed struct(u5) {
     precision: enum(u1) {
         normal = 0b0,
         inexact = 0b1,
-    },
+    } = .normal,
 };
 
 fn airRound(self: *Self, inst: Air.Inst.Index, mode: RoundMode) !void {
@@ -5413,11 +5413,11 @@ fn genRoundLibcall(self: *Self, ty: Type, src_mcv: MCValue, mode: RoundMode) !MC
     if (ty.zigTypeTag(mod) != .Float)
         return self.fail("TODO implement genRound for {}", .{ty.fmt(mod)});
 
-    var callee: ["__trunc?".len]u8 = undefined;
+    var callee_buf: ["__trunc?".len]u8 = undefined;
     return try self.genCall(.{ .lib = .{
         .return_type = ty.toIntern(),
         .param_types = &.{ty.toIntern()},
-        .callee = std.fmt.bufPrint(&callee, "{s}{s}{s}", .{
+        .callee = std.fmt.bufPrint(&callee_buf, "{s}{s}{s}", .{
             floatLibcAbiPrefix(ty),
             switch (mode.mode) {
                 .down => "floor",
@@ -5628,11 +5628,11 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
                     80, 128 => true,
                     else => unreachable,
                 }) {
-                    var callee: ["__sqrt?".len]u8 = undefined;
+                    var callee_buf: ["__sqrt?".len]u8 = undefined;
                     break :result try self.genCall(.{ .lib = .{
                         .return_type = ty.toIntern(),
                         .param_types = &.{ty.toIntern()},
-                        .callee = std.fmt.bufPrint(&callee, "{s}sqrt{s}", .{
+                        .callee = std.fmt.bufPrint(&callee_buf, "{s}sqrt{s}", .{
                             floatLibcAbiPrefix(ty),
                             floatLibcAbiSuffix(ty),
                         }) catch unreachable,
@@ -5665,7 +5665,7 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
                         .{ .v_, .cvtps2ph },
                         dst_reg,
                         dst_reg,
-                        Immediate.u(0b1_00),
+                        Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                     );
                     break :result dst_mcv;
                 },
@@ -5695,7 +5695,7 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 dst_reg,
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             break :result dst_mcv;
                         },
@@ -5720,7 +5720,7 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 wide_reg,
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             break :result dst_mcv;
                         },
@@ -5783,11 +5783,11 @@ fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
 fn airUnaryMath(self: *Self, inst: Air.Inst.Index, tag: Air.Inst.Tag) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const ty = self.typeOf(un_op);
-    var callee: ["__round?".len]u8 = undefined;
+    var callee_buf: ["__round?".len]u8 = undefined;
     const result = try self.genCall(.{ .lib = .{
         .return_type = ty.toIntern(),
         .param_types = &.{ty.toIntern()},
-        .callee = std.fmt.bufPrint(&callee, "{s}{s}{s}", .{
+        .callee = std.fmt.bufPrint(&callee_buf, "{s}{s}{s}", .{
             floatLibcAbiPrefix(ty),
             switch (tag) {
                 .sin,
@@ -6978,11 +6978,11 @@ fn genMulDivBinOp(
                     ),
                     else => {},
                 };
-                var callee: ["__udiv?i3".len]u8 = undefined;
+                var callee_buf: ["__udiv?i3".len]u8 = undefined;
                 return try self.genCall(.{ .lib = .{
                     .return_type = dst_ty.toIntern(),
                     .param_types = &.{ src_ty.toIntern(), src_ty.toIntern() },
-                    .callee = std.fmt.bufPrint(&callee, "__{s}{s}{c}i3", .{
+                    .callee = std.fmt.bufPrint(&callee_buf, "__{s}{s}{c}i3", .{
                         if (signed) "" else "u",
                         switch (tag) {
                             .div_trunc, .div_exact => "div",
@@ -7205,43 +7205,163 @@ fn genBinOp(
     const rhs_ty = self.typeOf(rhs_air);
     const abi_size: u32 = @intCast(lhs_ty.abiSize(mod));
 
-    if (lhs_ty.isRuntimeFloat() and (air_tag == .rem or switch (lhs_ty.floatBits(self.target.*)) {
-        16 => !self.hasFeature(.f16c),
-        32, 64 => false,
-        80, 128 => true,
-        else => unreachable,
-    })) {
-        var callee: ["__mod?f3".len]u8 = undefined;
+    if (lhs_ty.isRuntimeFloat()) libcall: {
+        const float_bits = lhs_ty.floatBits(self.target.*);
+        const type_needs_libcall = switch (float_bits) {
+            16 => !self.hasFeature(.f16c),
+            32, 64 => false,
+            80, 128 => true,
+            else => unreachable,
+        };
+        switch (air_tag) {
+            .rem, .mod => {},
+            else => if (!type_needs_libcall) break :libcall,
+        }
+        var callee_buf: ["__mod?f3".len]u8 = undefined;
+        const callee = switch (air_tag) {
+            .add,
+            .sub,
+            .mul,
+            .div_float,
+            .div_trunc,
+            .div_floor,
+            => std.fmt.bufPrint(&callee_buf, "__{s}{c}f3", .{
+                @tagName(air_tag)[0..3],
+                floatCompilerRtAbiName(float_bits),
+            }),
+            .rem, .mod, .min, .max => std.fmt.bufPrint(&callee_buf, "{s}f{s}{s}", .{
+                floatLibcAbiPrefix(lhs_ty),
+                switch (air_tag) {
+                    .rem, .mod => "mod",
+                    .min => "min",
+                    .max => "max",
+                    else => unreachable,
+                },
+                floatLibcAbiSuffix(lhs_ty),
+            }),
+            else => return self.fail("TODO implement genBinOp for {s} {}", .{
+                @tagName(air_tag), lhs_ty.fmt(mod),
+            }),
+        } catch unreachable;
         const result = try self.genCall(.{ .lib = .{
             .return_type = lhs_ty.toIntern(),
             .param_types = &.{ lhs_ty.toIntern(), rhs_ty.toIntern() },
-            .callee = switch (air_tag) {
-                .add,
-                .sub,
-                .mul,
-                .div_float,
-                .div_trunc,
-                .div_floor,
-                => std.fmt.bufPrint(&callee, "__{s}{c}f3", .{
-                    @tagName(air_tag)[0..3],
-                    floatCompilerRtAbiName(lhs_ty.floatBits(self.target.*)),
-                }),
-                .rem, .min, .max => std.fmt.bufPrint(&callee, "{s}f{s}{s}", .{
-                    floatLibcAbiPrefix(lhs_ty),
-                    switch (air_tag) {
-                        .rem => "mod",
-                        .min => "min",
-                        .max => "max",
-                        else => unreachable,
-                    },
-                    floatLibcAbiSuffix(lhs_ty),
-                }),
-                else => return self.fail("TODO implement genBinOp for {s} {}", .{
-                    @tagName(air_tag), lhs_ty.fmt(mod),
-                }),
-            } catch unreachable,
+            .callee = callee,
         } }, &.{ lhs_ty, rhs_ty }, &.{ .{ .air_ref = lhs_air }, .{ .air_ref = rhs_air } });
         return switch (air_tag) {
+            .mod => result: {
+                const adjusted: MCValue = if (type_needs_libcall) adjusted: {
+                    var add_callee_buf: ["__add?f3".len]u8 = undefined;
+                    break :adjusted try self.genCall(.{ .lib = .{
+                        .return_type = lhs_ty.toIntern(),
+                        .param_types = &.{
+                            lhs_ty.toIntern(),
+                            rhs_ty.toIntern(),
+                        },
+                        .callee = std.fmt.bufPrint(&add_callee_buf, "__add{c}f3", .{
+                            floatCompilerRtAbiName(float_bits),
+                        }) catch unreachable,
+                    } }, &.{ lhs_ty, rhs_ty }, &.{ result, .{ .air_ref = rhs_air } });
+                } else switch (float_bits) {
+                    16, 32, 64 => adjusted: {
+                        const dst_reg = switch (result) {
+                            .register => |reg| reg,
+                            else => if (maybe_inst) |inst|
+                                (try self.copyToRegisterWithInstTracking(inst, lhs_ty, result)).register
+                            else
+                                try self.copyToTmpRegister(lhs_ty, result),
+                        };
+                        const dst_lock = self.register_manager.lockReg(dst_reg);
+                        defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
+
+                        const rhs_mcv = try self.resolveInst(rhs_air);
+                        const src_mcv: MCValue = if (float_bits == 16) src: {
+                            assert(self.hasFeature(.f16c));
+                            const tmp_reg = (try self.register_manager.allocReg(
+                                null,
+                                abi.RegisterClass.sse,
+                            )).to128();
+                            const tmp_lock = self.register_manager.lockRegAssumeUnused(tmp_reg);
+                            defer self.register_manager.unlockReg(tmp_lock);
+
+                            if (rhs_mcv.isMemory()) try self.asmRegisterRegisterMemoryImmediate(
+                                .{ .vp_w, .insr },
+                                dst_reg,
+                                dst_reg,
+                                rhs_mcv.mem(.word),
+                                Immediate.u(1),
+                            ) else try self.asmRegisterRegisterRegister(
+                                .{ .vp_, .unpcklwd },
+                                dst_reg,
+                                dst_reg,
+                                (if (rhs_mcv.isRegister())
+                                    rhs_mcv.getReg().?
+                                else
+                                    try self.copyToTmpRegister(rhs_ty, rhs_mcv)).to128(),
+                            );
+                            try self.asmRegisterRegister(.{ .v_ps, .cvtph2 }, dst_reg, dst_reg);
+                            break :src .{ .register = tmp_reg };
+                        } else rhs_mcv;
+
+                        if (self.hasFeature(.avx)) {
+                            const mir_tag: Mir.Inst.FixedTag = switch (float_bits) {
+                                16, 32 => .{ .v_ss, .add },
+                                64 => .{ .v_sd, .add },
+                                else => unreachable,
+                            };
+                            if (src_mcv.isMemory()) try self.asmRegisterRegisterMemory(
+                                mir_tag,
+                                dst_reg,
+                                dst_reg,
+                                src_mcv.mem(Memory.PtrSize.fromBitSize(float_bits)),
+                            ) else try self.asmRegisterRegisterRegister(
+                                mir_tag,
+                                dst_reg,
+                                dst_reg,
+                                (if (src_mcv.isRegister())
+                                    src_mcv.getReg().?
+                                else
+                                    try self.copyToTmpRegister(rhs_ty, src_mcv)).to128(),
+                            );
+                        } else {
+                            const mir_tag: Mir.Inst.FixedTag = switch (float_bits) {
+                                32 => .{ ._ss, .add },
+                                64 => .{ ._sd, .add },
+                                else => unreachable,
+                            };
+                            if (src_mcv.isMemory()) try self.asmRegisterMemory(
+                                mir_tag,
+                                dst_reg,
+                                src_mcv.mem(Memory.PtrSize.fromBitSize(float_bits)),
+                            ) else try self.asmRegisterRegister(
+                                mir_tag,
+                                dst_reg,
+                                (if (src_mcv.isRegister())
+                                    src_mcv.getReg().?
+                                else
+                                    try self.copyToTmpRegister(rhs_ty, src_mcv)).to128(),
+                            );
+                        }
+
+                        if (float_bits == 16) try self.asmRegisterRegisterImmediate(
+                            .{ .v_, .cvtps2ph },
+                            dst_reg,
+                            dst_reg,
+                            Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
+                        );
+                        break :adjusted .{ .register = dst_reg };
+                    },
+                    80, 128 => return self.fail("TODO implement genBinOp for {s} of {}", .{
+                        @tagName(air_tag), lhs_ty.fmt(mod),
+                    }),
+                    else => unreachable,
+                };
+                break :result try self.genCall(.{ .lib = .{
+                    .return_type = lhs_ty.toIntern(),
+                    .param_types = &.{ lhs_ty.toIntern(), rhs_ty.toIntern() },
+                    .callee = callee,
+                } }, &.{ lhs_ty, rhs_ty }, &.{ adjusted, .{ .air_ref = rhs_air } });
+            },
             .div_trunc, .div_floor => try self.genRoundLibcall(lhs_ty, result, .{
                 .mode = switch (air_tag) {
                     .div_trunc => .zero,
@@ -7263,6 +7383,7 @@ fn genBinOp(
 
     const maybe_mask_reg = switch (air_tag) {
         else => null,
+        .rem, .mod => unreachable,
         .max, .min => if (lhs_ty.scalarType(mod).isRuntimeFloat()) registerAlias(
             if (!self.hasFeature(.avx) and self.hasFeature(.sse4_1)) mask: {
                 try self.register_manager.getReg(.xmm0, null);
@@ -7270,9 +7391,6 @@ fn genBinOp(
             } else try self.register_manager.allocReg(null, abi.RegisterClass.sse),
             abi_size,
         ) else null,
-        .rem, .mod => return self.fail("TODO implement genBinOp for {s} {}", .{
-            @tagName(air_tag), lhs_ty.fmt(mod),
-        }),
     };
     const mask_lock =
         if (maybe_mask_reg) |mask_reg| self.register_manager.lockRegAssumeUnused(mask_reg) else null;
@@ -7667,7 +7785,7 @@ fn genBinOp(
                     .{ .v_, .cvtps2ph },
                     dst_reg,
                     dst_reg,
-                    Immediate.u(0b1_00),
+                    Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                 );
                 return dst_mcv;
             },
@@ -8096,7 +8214,7 @@ fn genBinOp(
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 dst_reg,
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             return dst_mcv;
                         },
@@ -8147,7 +8265,7 @@ fn genBinOp(
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 dst_reg,
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             return dst_mcv;
                         },
@@ -8190,7 +8308,7 @@ fn genBinOp(
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 dst_reg,
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             return dst_mcv;
                         },
@@ -8233,7 +8351,7 @@ fn genBinOp(
                                 .{ .v_, .cvtps2ph },
                                 dst_reg,
                                 dst_reg.to256(),
-                                Immediate.u(0b1_00),
+                                Immediate.u(@as(u5, @bitCast(RoundMode{ .mode = .mxcsr }))),
                             );
                             return dst_mcv;
                         },
@@ -9607,11 +9725,11 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: math.CompareOperator) !void {
                     80, 128 => true,
                     else => unreachable,
                 }) {
-                    var callee: ["__???f2".len]u8 = undefined;
+                    var callee_buf: ["__???f2".len]u8 = undefined;
                     const ret = try self.genCall(.{ .lib = .{
                         .return_type = .i32_type,
                         .param_types = &.{ ty.toIntern(), ty.toIntern() },
-                        .callee = std.fmt.bufPrint(&callee, "__{s}{c}f2", .{
+                        .callee = std.fmt.bufPrint(&callee_buf, "__{s}{c}f2", .{
                             switch (op) {
                                 .eq => "eq",
                                 .neq => "ne",
@@ -12224,11 +12342,11 @@ fn airFloatFromInt(self: *Self, inst: Air.Inst.Index) !void {
                 src_ty.fmt(mod), dst_ty.fmt(mod),
             });
 
-            var callee: ["__floatun?i?f".len]u8 = undefined;
+            var callee_buf: ["__floatun?i?f".len]u8 = undefined;
             break :result try self.genCall(.{ .lib = .{
                 .return_type = dst_ty.toIntern(),
                 .param_types = &.{src_ty.toIntern()},
-                .callee = std.fmt.bufPrint(&callee, "__float{s}{c}i{c}f", .{
+                .callee = std.fmt.bufPrint(&callee_buf, "__float{s}{c}i{c}f", .{
                     switch (src_signedness) {
                         .signed => "",
                         .unsigned => "un",
@@ -12303,11 +12421,11 @@ fn airIntFromFloat(self: *Self, inst: Air.Inst.Index) !void {
                 src_ty.fmt(mod), dst_ty.fmt(mod),
             });
 
-            var callee: ["__fixuns?f?i".len]u8 = undefined;
+            var callee_buf: ["__fixuns?f?i".len]u8 = undefined;
             break :result try self.genCall(.{ .lib = .{
                 .return_type = dst_ty.toIntern(),
                 .param_types = &.{src_ty.toIntern()},
-                .callee = std.fmt.bufPrint(&callee, "__fix{s}{c}f{c}i", .{
+                .callee = std.fmt.bufPrint(&callee_buf, "__fix{s}{c}f{c}i", .{
                     switch (dst_signedness) {
                         .signed => "",
                         .unsigned => "uns",
@@ -13516,11 +13634,11 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
                 ty.fmt(mod),
             });
 
-            var callee: ["__fma?".len]u8 = undefined;
+            var callee_buf: ["__fma?".len]u8 = undefined;
             break :result try self.genCall(.{ .lib = .{
                 .return_type = ty.toIntern(),
                 .param_types = &.{ ty.toIntern(), ty.toIntern(), ty.toIntern() },
-                .callee = std.fmt.bufPrint(&callee, "{s}fma{s}", .{
+                .callee = std.fmt.bufPrint(&callee_buf, "{s}fma{s}", .{
                     floatLibcAbiPrefix(ty),
                     floatLibcAbiSuffix(ty),
                 }) catch unreachable,
test/behavior/math.zig
@@ -1321,7 +1321,7 @@ test "remainder division" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_c and comptime builtin.cpu.arch.isArmOrThumb()) return error.SkipZigTest;
 
     if (builtin.zig_backend == .stage2_llvm and builtin.os.tag == .windows) {
@@ -1401,9 +1401,9 @@ test "float modulo division using @mod" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf) return error.SkipZigTest;
 
     try comptime fmod(f16);
     try comptime fmod(f32);