Commit a2df84d0ff

Sean <69403556+SeanTheGleaming@users.noreply.github.com>
2024-03-29 10:33:57
std.math: rework modf
- Changed `modf_result` to `Modf` to better fit naming conventions - Reworked `modf` to be far simpler and support all floating point types (as well as vectors) (I have done benchmarks and can confirm that the performance is roughly equivalent to the old implementation) - Added more descriptive tests for modf - Deprecated `modf32_result` and `modf64_result` in favor of `Modf(f32)` and `Modf(f64)` respectively
1 parent 2d443cd
Changed files (2)
lib
lib/std/math/modf.zig
@@ -1,207 +1,141 @@
-// Ported from musl, which is licensed under the MIT license:
-// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
-//
-// https://git.musl-libc.org/cgit/musl/tree/src/math/modff.c
-// https://git.musl-libc.org/cgit/musl/tree/src/math/modf.c
-
 const std = @import("../std.zig");
 const math = std.math;
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
-const maxInt = std.math.maxInt;
+const expectApproxEqAbs = std.testing.expectApproxEqAbs;
 
-fn modf_result(comptime T: type) type {
+pub fn Modf(comptime T: type) type {
     return struct {
         fpart: T,
         ipart: T,
     };
 }
-pub const modf32_result = modf_result(f32);
-pub const modf64_result = modf_result(f64);
 
 /// Returns the integer and fractional floating-point numbers that sum to x. The sign of each
 /// result is the same as the sign of x.
+/// In comptime, may be used with comptime_float
 ///
 /// Special Cases:
 ///  - modf(+-inf) = +-inf, nan
 ///  - modf(nan)   = nan, nan
-pub fn modf(x: anytype) modf_result(@TypeOf(x)) {
-    const T = @TypeOf(x);
-    return switch (T) {
-        f32 => modf32(x),
-        f64 => modf64(x),
-        else => @compileError("modf not implemented for " ++ @typeName(T)),
+pub fn modf(x: anytype) Modf(@TypeOf(x)) {
+    const ipart = @trunc(x);
+    return .{
+        .ipart = ipart,
+        .fpart = x - ipart,
     };
 }
 
-fn modf32(x: f32) modf32_result {
-    var result: modf32_result = undefined;
+test modf {
+    inline for ([_]type{ f16, f32, f64, f80, f128 }) |T| {
+        const epsilon: comptime_float = @max(1e-6, math.floatEps(T));
 
-    const u: u32 = @bitCast(x);
-    const e = @as(i32, @intCast((u >> 23) & 0xFF)) - 0x7F;
-    const us = u & 0x80000000;
+        var r: Modf(T) = undefined;
 
-    // TODO: Shouldn't need this.
-    if (math.isInf(x)) {
-        result.ipart = x;
-        result.fpart = math.nan(f32);
-        return result;
-    }
+        r = modf(@as(T, 1.0));
+        try expectEqual(1.0, r.ipart);
+        try expectEqual(0.0, r.fpart);
 
-    // no fractional part
-    if (e >= 23) {
-        result.ipart = x;
-        if (e == 0x80 and u << 9 != 0) { // nan
-            result.fpart = x;
-        } else {
-            result.fpart = @as(f32, @bitCast(us));
-        }
-        return result;
-    }
+        r = modf(@as(T, 0.34682));
+        try expectEqual(0.0, r.ipart);
+        try expectApproxEqAbs(@as(T, 0.34682), r.fpart, epsilon);
 
-    // no integral part
-    if (e < 0) {
-        result.ipart = @as(f32, @bitCast(us));
-        result.fpart = x;
-        return result;
-    }
+        r = modf(@as(T, 2.54576));
+        try expectEqual(2.0, r.ipart);
+        try expectApproxEqAbs(0.54576, r.fpart, epsilon);
 
-    const mask = @as(u32, 0x007FFFFF) >> @as(u5, @intCast(e));
-    if (u & mask == 0) {
-        result.ipart = x;
-        result.fpart = @as(f32, @bitCast(us));
-        return result;
+        r = modf(@as(T, 3.9782));
+        try expectEqual(3.0, r.ipart);
+        try expectApproxEqAbs(0.9782, r.fpart, epsilon);
     }
-
-    const uf: f32 = @bitCast(u & ~mask);
-    result.ipart = uf;
-    result.fpart = x - uf;
-    return result;
 }
 
-fn modf64(x: f64) modf64_result {
-    var result: modf64_result = undefined;
-
-    const u: u64 = @bitCast(x);
-    const e = @as(i32, @intCast((u >> 52) & 0x7FF)) - 0x3FF;
-    const us = u & (1 << 63);
-
-    if (math.isInf(x)) {
-        result.ipart = x;
-        result.fpart = math.nan(f64);
-        return result;
-    }
-
-    // no fractional part
-    if (e >= 52) {
-        result.ipart = x;
-        if (e == 0x400 and u << 12 != 0) { // nan
-            result.fpart = x;
-        } else {
-            result.fpart = @as(f64, @bitCast(us));
+/// Generate a namespace of tests for modf on values of the given type
+fn ModfTests(comptime T: type) type {
+    return struct {
+        test "normal" {
+            const epsilon: comptime_float = @max(1e-6, math.floatEps(T));
+            var r: Modf(T) = undefined;
+
+            r = modf(@as(T, 1.0));
+            try expectEqual(1.0, r.ipart);
+            try expectEqual(0.0, r.fpart);
+
+            r = modf(@as(T, 0.34682));
+            try expectEqual(0.0, r.ipart);
+            try expectApproxEqAbs(0.34682, r.fpart, epsilon);
+
+            r = modf(@as(T, 3.97812));
+            try expectEqual(3.0, r.ipart);
+            // account for precision error
+            const expected_a: T = 3.97812 - @as(T, 3);
+            try expectApproxEqAbs(expected_a, r.fpart, epsilon);
+
+            r = modf(@as(T, 43874.3));
+            try expectEqual(43874.0, r.ipart);
+            // account for precision error
+            const expected_b: T = 43874.3 - @as(T, 43874);
+            try expectApproxEqAbs(expected_b, r.fpart, epsilon);
+
+            r = modf(@as(T, 1234.340780));
+            try expectEqual(1234.0, r.ipart);
+            // account for precision error
+            const expected_c: T = 1234.340780 - @as(T, 1234);
+            try expectApproxEqAbs(expected_c, r.fpart, epsilon);
         }
-        return result;
-    }
-
-    // no integral part
-    if (e < 0) {
-        result.ipart = @as(f64, @bitCast(us));
-        result.fpart = x;
-        return result;
-    }
-
-    const mask = @as(u64, maxInt(u64) >> 12) >> @as(u6, @intCast(e));
-    if (u & mask == 0) {
-        result.ipart = x;
-        result.fpart = @as(f64, @bitCast(us));
-        return result;
-    }
-
-    const uf = @as(f64, @bitCast(u & ~mask));
-    result.ipart = uf;
-    result.fpart = x - uf;
-    return result;
-}
+        test "vector" {
+            // Currently, a compiler bug is breaking the usage
+            // of @trunc on @Vector types
 
-test modf {
-    const a = modf(@as(f32, 1.0));
-    const b = modf32(1.0);
-    // NOTE: No struct comparison on generic return type function? non-named, makes sense, but still.
-    try expectEqual(a, b);
-}
-
-test modf32 {
-    const epsilon = 0.000001;
-    var r: modf32_result = undefined;
-
-    r = modf32(1.0);
-    try expect(math.approxEqAbs(f32, r.ipart, 1.0, epsilon));
-    try expect(math.approxEqAbs(f32, r.fpart, 0.0, epsilon));
-
-    r = modf32(2.545);
-    try expect(math.approxEqAbs(f32, r.ipart, 2.0, epsilon));
-    try expect(math.approxEqAbs(f32, r.fpart, 0.545, epsilon));
+            // TODO: Repopulate the below array and
+            // remove the skip statement once this
+            // bug is fixed
 
-    r = modf32(3.978123);
-    try expect(math.approxEqAbs(f32, r.ipart, 3.0, epsilon));
-    try expect(math.approxEqAbs(f32, r.fpart, 0.978123, epsilon));
+            // const widths = [_]comptime_int{ 1, 2, 3, 4, 8, 16 };
+            const widths = [_]comptime_int{};
 
-    r = modf32(43874.3);
-    try expect(math.approxEqAbs(f32, r.ipart, 43874, epsilon));
-    try expect(math.approxEqAbs(f32, r.fpart, 0.300781, epsilon));
+            if (widths.len == 0)
+                return error.SkipZigTest;
 
-    r = modf32(1234.340780);
-    try expect(math.approxEqAbs(f32, r.ipart, 1234, epsilon));
-    try expect(math.approxEqAbs(f32, r.fpart, 0.340820, epsilon));
-}
-
-test modf64 {
-    const epsilon = 0.000001;
-    var r: modf64_result = undefined;
-
-    r = modf64(1.0);
-    try expect(math.approxEqAbs(f64, r.ipart, 1.0, epsilon));
-    try expect(math.approxEqAbs(f64, r.fpart, 0.0, epsilon));
+            inline for (widths) |len| {
+                const V: type = @Vector(len, T);
+                var r: Modf(V) = undefined;
 
-    r = modf64(2.545);
-    try expect(math.approxEqAbs(f64, r.ipart, 2.0, epsilon));
-    try expect(math.approxEqAbs(f64, r.fpart, 0.545, epsilon));
+                r = modf(@as(V, @splat(1.0)));
+                try expectEqual(@as(V, @splat(1.0)), r.ipart);
+                try expectEqual(@as(V, @splat(0.0)), r.fpart);
 
-    r = modf64(3.978123);
-    try expect(math.approxEqAbs(f64, r.ipart, 3.0, epsilon));
-    try expect(math.approxEqAbs(f64, r.fpart, 0.978123, epsilon));
+                r = modf(@as(V, @splat(2.75)));
+                try expectEqual(@as(V, @splat(2.0)), r.ipart);
+                try expectEqual(@as(V, @splat(0.75)), r.fpart);
 
-    r = modf64(43874.3);
-    try expect(math.approxEqAbs(f64, r.ipart, 43874, epsilon));
-    try expect(math.approxEqAbs(f64, r.fpart, 0.3, epsilon));
-
-    r = modf64(1234.340780);
-    try expect(math.approxEqAbs(f64, r.ipart, 1234, epsilon));
-    try expect(math.approxEqAbs(f64, r.fpart, 0.340780, epsilon));
-}
+                r = modf(@as(V, @splat(0.2)));
+                try expectEqual(@as(V, @splat(0.0)), r.ipart);
+                try expectEqual(@as(V, @splat(0.2)), r.fpart);
 
-test "modf32.special" {
-    var r: modf32_result = undefined;
-
-    r = modf32(math.inf(f32));
-    try expect(math.isPositiveInf(r.ipart) and math.isNan(r.fpart));
+                r = modf(std.simd.iota(T, len) + @as(V, @splat(0.5)));
+                try expectEqual(std.simd.iota(T, len), r.ipart);
+                try expectEqual(@as(V, @splat(0.5)), r.fpart);
+            }
+        }
+        test "inf" {
+            var r: Modf(T) = undefined;
 
-    r = modf32(-math.inf(f32));
-    try expect(math.isNegativeInf(r.ipart) and math.isNan(r.fpart));
+            r = modf(math.inf(T));
+            try expect(math.isPositiveInf(r.ipart) and math.isNan(r.fpart));
 
-    r = modf32(math.nan(f32));
-    try expect(math.isNan(r.ipart) and math.isNan(r.fpart));
+            r = modf(-math.inf(T));
+            try expect(math.isNegativeInf(r.ipart) and math.isNan(r.fpart));
+        }
+        test "nan" {
+            const r: Modf(T) = modf(math.nan(T));
+            try expect(math.isNan(r.ipart) and math.isNan(r.fpart));
+        }
+    };
 }
 
-test "modf64.special" {
-    var r: modf64_result = undefined;
-
-    r = modf64(math.inf(f64));
-    try expect(math.isPositiveInf(r.ipart) and math.isNan(r.fpart));
-
-    r = modf64(-math.inf(f64));
-    try expect(math.isNegativeInf(r.ipart) and math.isNan(r.fpart));
-
-    r = modf64(math.nan(f64));
-    try expect(math.isNan(r.ipart) and math.isNan(r.fpart));
+comptime {
+    for ([_]type{ f16, f32, f64, f80, f128 }) |T| {
+        _ = ModfTests(T);
+    }
 }
lib/std/math.zig
@@ -112,6 +112,8 @@ pub const qnan_f80 = @compileError("Deprecated: use `nan(f80)` instead");
 pub const qnan_u128 = @compileError("Deprecated: use `@as(u128, @bitCast(nan(f128)))` instead");
 pub const qnan_f128 = @compileError("Deprecated: use `nan(f128)` instead");
 pub const epsilon = @compileError("Deprecated: use `floatEps` instead");
+pub const modf32_result = @compileError("Deprecated: use `Modf(f32)` instead");
+pub const modf64_result = @compileError("Deprecated: use `Modf(f64)` instead");
 
 /// Performs an approximate comparison of two floating point values `x` and `y`.
 /// Returns true if the absolute difference between them is less or equal than
@@ -255,8 +257,7 @@ pub const isSignalNan = @import("math/isnan.zig").isSignalNan;
 pub const frexp = @import("math/frexp.zig").frexp;
 pub const Frexp = @import("math/frexp.zig").Frexp;
 pub const modf = @import("math/modf.zig").modf;
-pub const modf32_result = @import("math/modf.zig").modf32_result;
-pub const modf64_result = @import("math/modf.zig").modf64_result;
+pub const Modf = @import("math/modf.zig").Modf;
 pub const copysign = @import("math/copysign.zig").copysign;
 pub const isFinite = @import("math/isfinite.zig").isFinite;
 pub const isInf = @import("math/isinf.zig").isInf;
@@ -418,8 +419,7 @@ test {
     _ = frexp;
     _ = Frexp;
     _ = modf;
-    _ = modf32_result;
-    _ = modf64_result;
+    _ = Modf;
     _ = copysign;
     _ = isFinite;
     _ = isInf;