master
  1// Ported from musl, which is licensed under the MIT license:
  2// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
  3//
  4// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrtf.c
  5// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c
  6
  7const std = @import("../../std.zig");
  8const testing = std.testing;
  9const math = std.math;
 10const cmath = math.complex;
 11const Complex = cmath.Complex;
 12
 13/// Returns the square root of z. The real and imaginary parts of the result have the same sign
 14/// as the imaginary part of z.
 15pub fn sqrt(z: anytype) Complex(@TypeOf(z.re, z.im)) {
 16    const T = @TypeOf(z.re, z.im);
 17
 18    return switch (T) {
 19        f32 => sqrt32(z),
 20        f64 => sqrt64(z),
 21        else => @compileError("sqrt not implemented for " ++ @typeName(T)),
 22    };
 23}
 24
 25fn sqrt32(z: Complex(f32)) Complex(f32) {
 26    const x = z.re;
 27    const y = z.im;
 28
 29    if (x == 0 and y == 0) {
 30        return Complex(f32).init(0, y);
 31    }
 32    if (math.isInf(y)) {
 33        return Complex(f32).init(math.inf(f32), y);
 34    }
 35    if (math.isNan(x)) {
 36        // raise invalid if y is not nan
 37        const t = (y - y) / (y - y);
 38        return Complex(f32).init(x, t);
 39    }
 40    if (math.isInf(x)) {
 41        // sqrt(inf + i nan)    = inf + nan i
 42        // sqrt(inf + iy)       = inf + i0
 43        // sqrt(-inf + i nan)   = nan +- inf i
 44        // sqrt(-inf + iy)      = 0 + inf i
 45        if (math.signbit(x)) {
 46            return Complex(f32).init(@abs(y - y), math.copysign(x, y));
 47        } else {
 48            return Complex(f32).init(x, math.copysign(y - y, y));
 49        }
 50    }
 51
 52    // y = nan special case is handled fine below
 53
 54    // double-precision avoids overflow with correct rounding.
 55    const dx = @as(f64, x);
 56    const dy = @as(f64, y);
 57
 58    if (dx >= 0) {
 59        const t = @sqrt((dx + math.hypot(dx, dy)) * 0.5);
 60        return Complex(f32).init(
 61            @as(f32, @floatCast(t)),
 62            @as(f32, @floatCast(dy / (2.0 * t))),
 63        );
 64    } else {
 65        const t = @sqrt((-dx + math.hypot(dx, dy)) * 0.5);
 66        return Complex(f32).init(
 67            @as(f32, @floatCast(@abs(y) / (2.0 * t))),
 68            @as(f32, @floatCast(math.copysign(t, y))),
 69        );
 70    }
 71}
 72
 73fn sqrt64(z: Complex(f64)) Complex(f64) {
 74    // may encounter overflow for im,re >= DBL_MAX / (1 + sqrt(2))
 75    const threshold = 0x1.a827999fcef32p+1022;
 76
 77    var x = z.re;
 78    var y = z.im;
 79
 80    if (x == 0 and y == 0) {
 81        return Complex(f64).init(0, y);
 82    }
 83    if (math.isInf(y)) {
 84        return Complex(f64).init(math.inf(f64), y);
 85    }
 86    if (math.isNan(x)) {
 87        // raise invalid if y is not nan
 88        const t = (y - y) / (y - y);
 89        return Complex(f64).init(x, t);
 90    }
 91    if (math.isInf(x)) {
 92        // sqrt(inf + i nan)    = inf + nan i
 93        // sqrt(inf + iy)       = inf + i0
 94        // sqrt(-inf + i nan)   = nan +- inf i
 95        // sqrt(-inf + iy)      = 0 + inf i
 96        if (math.signbit(x)) {
 97            return Complex(f64).init(@abs(y - y), math.copysign(x, y));
 98        } else {
 99            return Complex(f64).init(x, math.copysign(y - y, y));
100        }
101    }
102
103    // y = nan special case is handled fine below
104
105    // scale to avoid overflow
106    var scale = false;
107    if (@abs(x) >= threshold or @abs(y) >= threshold) {
108        x *= 0.25;
109        y *= 0.25;
110        scale = true;
111    }
112
113    var result: Complex(f64) = undefined;
114    if (x >= 0) {
115        const t = @sqrt((x + math.hypot(x, y)) * 0.5);
116        result = Complex(f64).init(t, y / (2.0 * t));
117    } else {
118        const t = @sqrt((-x + math.hypot(x, y)) * 0.5);
119        result = Complex(f64).init(@abs(y) / (2.0 * t), math.copysign(t, y));
120    }
121
122    if (scale) {
123        result.re *= 2;
124        result.im *= 2;
125    }
126
127    return result;
128}
129
test sqrt32 {
    const epsilon = math.floatEps(f32);
    const a = Complex(f32).init(5, 3);
    const c = sqrt(a);

    try testing.expectApproxEqAbs(2.3271174, c.re, epsilon);
    try testing.expectApproxEqAbs(0.6445742, c.im, epsilon);

    // Negative real axis: sqrt(-4 + 0i) = 0 + 2i.
    const neg = sqrt(Complex(f32).init(-4, 0));
    try testing.expectEqual(@as(f32, 0), neg.re);
    try testing.expectEqual(@as(f32, 2), neg.im);

    // Zero maps to zero.
    const zero = sqrt(Complex(f32).init(0, 0));
    try testing.expectEqual(@as(f32, 0), zero.re);
    try testing.expectEqual(@as(f32, 0), zero.im);

    // sqrt(-inf + iy) = 0 + inf i.
    const ninf = sqrt(Complex(f32).init(-math.inf(f32), 1));
    try testing.expectEqual(@as(f32, 0), ninf.re);
    try testing.expectEqual(math.inf(f32), ninf.im);

    // sqrt(inf + iy) = inf + 0i.
    const pinf = sqrt(Complex(f32).init(math.inf(f32), 2));
    try testing.expectEqual(math.inf(f32), pinf.re);
    try testing.expectEqual(@as(f32, 0), pinf.im);

    // sqrt(x - inf i) = inf - inf i.
    const iinf = sqrt(Complex(f32).init(3, -math.inf(f32)));
    try testing.expectEqual(math.inf(f32), iinf.re);
    try testing.expectEqual(-math.inf(f32), iinf.im);

    // NaN real part gives NaN + NaN i.
    const nan_re = sqrt(Complex(f32).init(math.nan(f32), 1));
    try testing.expect(math.isNan(nan_re.re) and math.isNan(nan_re.im));
}
138
test sqrt64 {
    const epsilon = math.floatEps(f64);
    const a = Complex(f64).init(5, 3);
    const c = sqrt(a);

    try testing.expectApproxEqAbs(2.3271175190399496, c.re, epsilon);
    try testing.expectApproxEqAbs(0.6445742373246469, c.im, epsilon);

    // Negative real axis: sqrt(-4 + 0i) = 0 + 2i.
    const neg = sqrt(Complex(f64).init(-4, 0));
    try testing.expectEqual(@as(f64, 0), neg.re);
    try testing.expectEqual(@as(f64, 2), neg.im);

    // sqrt(-inf + iy) = 0 + inf i.
    const ninf = sqrt(Complex(f64).init(-math.inf(f64), 1));
    try testing.expectEqual(@as(f64, 0), ninf.re);
    try testing.expectEqual(math.inf(f64), ninf.im);

    // sqrt(inf + NaN i) = inf + NaN i.
    const pinf_nan = sqrt(Complex(f64).init(math.inf(f64), math.nan(f64)));
    try testing.expectEqual(math.inf(f64), pinf_nan.re);
    try testing.expect(math.isNan(pinf_nan.im));

    // NaN real part gives NaN + NaN i.
    const nan_re = sqrt(Complex(f64).init(math.nan(f64), 1));
    try testing.expect(math.isNan(nan_re.re) and math.isNan(nan_re.im));

    // Components near floatMax take the overflow-scaling path:
    // sqrt(M (1 + i)) = sqrt(M) * (sqrt((sqrt2 + 1) / 2) + i sqrt((sqrt2 - 1) / 2)).
    const big = sqrt(Complex(f64).init(math.floatMax(f64), math.floatMax(f64)));
    const root_max = @sqrt(math.floatMax(f64));
    const sqrt2 = @sqrt(@as(f64, 2));
    try testing.expectApproxEqRel(root_max * @sqrt((sqrt2 + 1) / 2), big.re, 1e-12);
    try testing.expectApproxEqRel(root_max * @sqrt((sqrt2 - 1) / 2), big.im, 1e-12);
}