Commit 0aeeff0d94

expikr <77922942+expikr@users.noreply.github.com>
2024-05-30 11:58:05
math.hypot: fix incorrect over/underflow behavior (#19472)
1 parent ce9d2ed
Changed files (3)
lib/std/math/float.zig
@@ -94,6 +94,19 @@ pub inline fn floatEps(comptime T: type) T {
     return reconstructFloat(T, -floatFractionalBits(T), mantissaOne(T));
 }
 
+/// Returns the local epsilon of floating point type T.
+pub inline fn floatEpsAt(comptime T: type, x: T) T {
+    switch (@typeInfo(T)) {
+        .Float => |F| {
+            const U: type = @Type(.{ .Int = .{ .signedness = .unsigned, .bits = F.bits } });
+            const u: U = @bitCast(x);
+            const y: T = @bitCast(u ^ 1);
+            return @abs(x - y);
+        },
+        else => @compileError("floatEpsAt only supports floats"),
+    }
+}
+
 /// Returns the value inf for floating point type T.
 pub inline fn inf(comptime T: type) T {
     return reconstructFloat(T, floatExponentMax(T) + 1, mantissaOne(T));
lib/std/math/hypot.zig
@@ -1,13 +1,14 @@
-// Ported from musl, which is licensed under the MIT license:
-// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
-//
-// https://git.musl-libc.org/cgit/musl/tree/src/math/hypotf.c
-// https://git.musl-libc.org/cgit/musl/tree/src/math/hypot.c
-
 const std = @import("../std.zig");
 const math = std.math;
 const expect = std.testing.expect;
-const maxInt = std.math.maxInt;
+const isNan = math.isNan;
+const isInf = math.isInf;
+const inf = math.inf;
+const nan = math.nan;
+const floatEpsAt = math.floatEpsAt;
+const floatEps = math.floatEps;
+const floatMin = math.floatMin;
+const floatMax = math.floatMax;
 
 /// Returns sqrt(x * x + y * y), avoiding unnecessary overflow and underflow.
 ///
@@ -15,162 +16,116 @@ const maxInt = std.math.maxInt;
 ///
 /// |   x   |   y   | hypot |
 /// |-------|-------|-------|
-/// | +inf  |  num  | +inf  |
-/// |  num  | +-inf | +inf  |
-/// |  nan  |  any  |  nan  |
-/// |  any  |  nan  |  nan  |
+/// | +-inf |  any  | +inf  |
+/// |  any  | +-inf | +inf  |
+/// |  nan  |  fin  |  nan  |
+/// |  fin  |  nan  |  nan  |
 pub fn hypot(x: anytype, y: anytype) @TypeOf(x, y) {
     const T = @TypeOf(x, y);
-    return switch (T) {
-        f32 => hypot32(x, y),
-        f64 => hypot64(x, y),
+    switch (@typeInfo(T)) {
+        .Float => {},
+        .ComptimeFloat => return @sqrt(x * x + y * y),
         else => @compileError("hypot not implemented for " ++ @typeName(T)),
-    };
-}
-
-fn hypot32(x: f32, y: f32) f32 {
-    var ux = @as(u32, @bitCast(x));
-    var uy = @as(u32, @bitCast(y));
-
-    ux &= maxInt(u32) >> 1;
-    uy &= maxInt(u32) >> 1;
-    if (ux < uy) {
-        const tmp = ux;
-        ux = uy;
-        uy = tmp;
     }
-
-    var xx = @as(f32, @bitCast(ux));
-    var yy = @as(f32, @bitCast(uy));
-    if (uy == 0xFF << 23) {
-        return yy;
-    }
-    if (ux >= 0xFF << 23 or uy == 0 or ux - uy >= (25 << 23)) {
-        return xx + yy;
-    }
-
-    var z: f32 = 1.0;
-    if (ux >= (0x7F + 60) << 23) {
-        z = 0x1.0p90;
-        xx *= 0x1.0p-90;
-        yy *= 0x1.0p-90;
-    } else if (uy < (0x7F - 60) << 23) {
-        z = 0x1.0p-90;
-        xx *= 0x1.0p-90;
-        yy *= 0x1.0p-90;
+    const lower = @sqrt(floatMin(T));
+    const upper = @sqrt(floatMax(T) / 2);
+    const incre = @sqrt(floatEps(T) / 2);
+    const scale = floatEpsAt(T, incre);
+    const hypfn = if (emulateFma(T)) hypotUnfused else hypotFused;
+    var major: T = x;
+    var minor: T = y;
+    if (isInf(major) or isInf(minor)) return inf(T);
+    if (isNan(major) or isNan(minor)) return nan(T);
+    if (T == f16) return @floatCast(@sqrt(@mulAdd(f32, x, x, @as(f32, y) * y)));
+    if (T == f32) return @floatCast(@sqrt(@mulAdd(f64, x, x, @as(f64, y) * y)));
+    major = @abs(major);
+    minor = @abs(minor);
+    if (minor > major) {
+        const tempo = major;
+        major = minor;
+        minor = tempo;
     }
-
-    return z * @sqrt(@as(f32, @floatCast(@as(f64, x) * x + @as(f64, y) * y)));
+    if (major * incre >= minor) return major;
+    if (major > upper) return hypfn(T, major * scale, minor * scale) / scale;
+    if (minor < lower) return hypfn(T, major / scale, minor / scale) * scale;
+    return hypfn(T, major, minor);
 }
 
-fn sq(hi: *f64, lo: *f64, x: f64) void {
-    const split: f64 = 0x1.0p27 + 1.0;
-    const xc = x * split;
-    const xh = x - xc + xc;
-    const xl = x - xh;
-    hi.* = x * x;
-    lo.* = xh * xh - hi.* + 2 * xh * xl + xl * xl;
+inline fn emulateFma(comptime T: type) bool {
+    // If @mulAdd lowers to the software implementation,
+    // hypotUnfused should be used in place of hypotFused.
+    // This takes an educated guess, but ideally we should
+    // properly detect at comptime when that fallback will
+    // occur.
+    return (T == f128 or T == f80);
 }
 
-fn hypot64(x: f64, y: f64) f64 {
-    var ux = @as(u64, @bitCast(x));
-    var uy = @as(u64, @bitCast(y));
-
-    ux &= maxInt(u64) >> 1;
-    uy &= maxInt(u64) >> 1;
-    if (ux < uy) {
-        const tmp = ux;
-        ux = uy;
-        uy = tmp;
-    }
-
-    const ex = ux >> 52;
-    const ey = uy >> 52;
-    var xx = @as(f64, @bitCast(ux));
-    var yy = @as(f64, @bitCast(uy));
-
-    // hypot(inf, nan) == inf
-    if (ey == 0x7FF) {
-        return yy;
-    }
-    if (ex == 0x7FF or uy == 0) {
-        return xx;
-    }
-
-    // hypot(x, y) ~= x + y * y / x / 2 with inexact for small y/x
-    if (ex - ey > 64) {
-        return xx + yy;
-    }
+inline fn hypotFused(comptime F: type, x: F, y: F) F {
+    const r = @sqrt(@mulAdd(F, x, x, y * y));
+    const rr = r * r;
+    const xx = x * x;
+    const z = @mulAdd(F, -y, y, rr - xx) + @mulAdd(F, r, r, -rr) - @mulAdd(F, x, x, -xx);
+    return r - z / (2 * r);
+}
 
-    var z: f64 = 1;
-    if (ex > 0x3FF + 510) {
-        z = 0x1.0p700;
-        xx *= 0x1.0p-700;
-        yy *= 0x1.0p-700;
-    } else if (ey < 0x3FF - 450) {
-        z = 0x1.0p-700;
-        xx *= 0x1.0p700;
-        yy *= 0x1.0p700;
+inline fn hypotUnfused(comptime F: type, x: F, y: F) F {
+    const r = @sqrt(x * x + y * y);
+    if (r <= 2 * y) { // 30deg or steeper
+        const dx = r - y;
+        const z = x * (2 * dx - x) + (dx - 2 * (x - y)) * dx;
+        return r - z / (2 * r);
+    } else { // shallower than 30 deg
+        const dy = r - x;
+        const z = 2 * dy * (x - 2 * y) + (4 * dy - y) * y + dy * dy;
+        return r - z / (2 * r);
     }
-
-    var hx: f64 = undefined;
-    var lx: f64 = undefined;
-    var hy: f64 = undefined;
-    var ly: f64 = undefined;
-
-    sq(&hx, &lx, x);
-    sq(&hy, &ly, y);
-
-    return z * @sqrt(ly + lx + hy + hx);
 }
 
+const hypot_test_cases = .{
+    .{ 0.0, -1.2, 1.2 },
+    .{ 0.2, -0.34, 0.3944616584663203993612799816649560759946493601889826495362 },
+    .{ 0.8923, 2.636890, 2.7837722899152509525110650481670176852603253522923737962880 },
+    .{ 1.5, 5.25, 5.4600824169603887033229768686452745953332522619323580787836 },
+    .{ 37.45, 159.835, 164.16372840856167640478217141034363907565754072954443805164 },
+    .{ 89.123, 382.028905, 392.28687638576315875933966414927490685367196874260165618371 },
+    .{ 123123.234375, 529428.707813, 543556.88524707706887251269205923830745438413088753096759371 },
+};
+
 test hypot {
-    const x32: f32 = 0.0;
-    const y32: f32 = -1.2;
-    const x64: f64 = 0.0;
-    const y64: f64 = -1.2;
-    try expect(hypot(x32, y32) == hypot32(0.0, -1.2));
-    try expect(hypot(x64, y64) == hypot64(0.0, -1.2));
+    try expect(hypot(0.3, 0.4) == 0.5);
 }
 
-test hypot32 {
-    const epsilon = 0.000001;
-
-    try expect(math.approxEqAbs(f32, hypot32(0.0, -1.2), 1.2, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(0.2, -0.34), 0.394462, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(0.8923, 2.636890), 2.783772, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(1.5, 5.25), 5.460083, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(37.45, 159.835), 164.163742, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(89.123, 382.028905), 392.286865, epsilon));
-    try expect(math.approxEqAbs(f32, hypot32(123123.234375, 529428.707813), 543556.875, epsilon));
+test "hypot.correct" {
+    inline for (.{ f16, f32, f64, f128 }) |T| {
+        inline for (hypot_test_cases) |v| {
+            const a: T, const b: T, const c: T = v;
+            try expect(math.approxEqRel(T, hypot(a, b), c, @sqrt(floatEps(T))));
+        }
+    }
 }
 
-test hypot64 {
-    const epsilon = 0.000001;
-
-    try expect(math.approxEqAbs(f64, hypot64(0.0, -1.2), 1.2, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(0.2, -0.34), 0.394462, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(0.8923, 2.636890), 2.783772, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(1.5, 5.25), 5.460082, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(37.45, 159.835), 164.163728, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(89.123, 382.028905), 392.286876, epsilon));
-    try expect(math.approxEqAbs(f64, hypot64(123123.234375, 529428.707813), 543556.885247, epsilon));
+test "hypot.precise" {
+    inline for (.{ f16, f32, f64 }) |T| { // f128 seems to be 5 ulp
+        inline for (hypot_test_cases) |v| {
+            const a: T, const b: T, const c: T = v;
+            try expect(math.approxEqRel(T, hypot(a, b), c, floatEps(T)));
+        }
+    }
 }
 
-test "hypot32.special" {
-    try expect(math.isPositiveInf(hypot32(math.inf(f32), 0.0)));
-    try expect(math.isPositiveInf(hypot32(-math.inf(f32), 0.0)));
-    try expect(math.isPositiveInf(hypot32(0.0, math.inf(f32))));
-    try expect(math.isPositiveInf(hypot32(0.0, -math.inf(f32))));
-    try expect(math.isNan(hypot32(math.nan(f32), 0.0)));
-    try expect(math.isNan(hypot32(0.0, math.nan(f32))));
-}
+test "hypot.special" {
+    inline for (.{ f16, f32, f64, f128 }) |T| {
+        try expect(math.isNan(hypot(nan(T), 0.0)));
+        try expect(math.isNan(hypot(0.0, nan(T))));
+
+        try expect(math.isPositiveInf(hypot(inf(T), 0.0)));
+        try expect(math.isPositiveInf(hypot(0.0, inf(T))));
+        try expect(math.isPositiveInf(hypot(inf(T), nan(T))));
+        try expect(math.isPositiveInf(hypot(nan(T), inf(T))));
 
-test "hypot64.special" {
-    try expect(math.isPositiveInf(hypot64(math.inf(f64), 0.0)));
-    try expect(math.isPositiveInf(hypot64(-math.inf(f64), 0.0)));
-    try expect(math.isPositiveInf(hypot64(0.0, math.inf(f64))));
-    try expect(math.isPositiveInf(hypot64(0.0, -math.inf(f64))));
-    try expect(math.isNan(hypot64(math.nan(f64), 0.0)));
-    try expect(math.isNan(hypot64(0.0, math.nan(f64))));
+        try expect(math.isPositiveInf(hypot(-inf(T), 0.0)));
+        try expect(math.isPositiveInf(hypot(0.0, -inf(T))));
+        try expect(math.isPositiveInf(hypot(-inf(T), nan(T))));
+        try expect(math.isPositiveInf(hypot(nan(T), -inf(T))));
+    }
 }
lib/std/math.zig
@@ -52,6 +52,7 @@ pub const floatTrueMin = @import("math/float.zig").floatTrueMin;
 pub const floatMin = @import("math/float.zig").floatMin;
 pub const floatMax = @import("math/float.zig").floatMax;
 pub const floatEps = @import("math/float.zig").floatEps;
+pub const floatEpsAt = @import("math/float.zig").floatEpsAt;
 pub const inf = @import("math/float.zig").inf;
 pub const nan = @import("math/float.zig").nan;
 pub const snan = @import("math/float.zig").snan;