// Ported from musl, which is licensed under the MIT license:
// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
//
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrtf.c
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c

7const std = @import("../../std.zig");
8const testing = std.testing;
9const math = std.math;
10const cmath = math.complex;
11const Complex = cmath.Complex;
12
/// Returns the principal square root of z.
///
/// The real part of the result is non-negative (or NaN when the input's real
/// part is NaN). The imaginary part has the same sign as the imaginary part of
/// z, including signed zeros. Only `f32` and `f64` components are supported;
/// any other component type is a compile error.
pub fn sqrt(z: anytype) Complex(@TypeOf(z.re, z.im)) {
    const T = @TypeOf(z.re, z.im);

    return switch (T) {
        f32 => sqrt32(z),
        f64 => sqrt64(z),
        else => @compileError("sqrt not implemented for " ++ @typeName(T)),
    };
}
24
/// Single-precision complex square root.
///
/// IEEE special values (zeros, infinities, NaN real part) are handled first.
/// The general case is then evaluated in f64 so that `hypot` and the
/// intermediate sums cannot overflow, and each component is narrowed back to
/// f32 exactly once.
fn sqrt32(z: Complex(f32)) Complex(f32) {
    const C = Complex(f32);
    const re = z.re;
    const im = z.im;

    // sqrt(+-0 + i*(+-0)) keeps the sign of the imaginary zero.
    if (re == 0 and im == 0) return C.init(0, im);

    // An infinite imaginary part takes precedence, even over a NaN real part.
    if (math.isInf(im)) return C.init(math.inf(f32), im);

    if (math.isNan(re)) {
        // (im - im) / (im - im) is NaN for every im; raises invalid if im is not NaN.
        return C.init(re, (im - im) / (im - im));
    }

    if (math.isInf(re)) {
        // sqrt(+inf + i nan) = +inf + nan i     sqrt(+inf + iy) = +inf + 0 i
        // sqrt(-inf + i nan) = nan +- inf i     sqrt(-inf + iy) = 0 + inf i
        const im_diff = im - im; // 0 for finite im, NaN when im is NaN
        return if (math.signbit(re))
            C.init(@abs(im_diff), math.copysign(re, im))
        else
            C.init(re, math.copysign(im_diff, im));
    }

    // A NaN imaginary part needs no special handling: it propagates below.

    // Widen to f64 to avoid overflow and keep rounding correct.
    const a: f64 = re;
    const b: f64 = im;

    // |a| + hypot(a, b) is a + hypot for a >= 0 and -a + hypot for a < 0.
    const t = @sqrt((@abs(a) + math.hypot(a, b)) * 0.5);

    if (a < 0) {
        // Left half-plane: t carries the imaginary magnitude.
        return C.init(
            @floatCast(@abs(b) / (2.0 * t)),
            @floatCast(math.copysign(t, b)),
        );
    }
    return C.init(@floatCast(t), @floatCast(b / (2.0 * t)));
}
72
/// Double-precision complex square root.
///
/// After the IEEE special values are handled, the principal root is computed
/// with Algorithm 312 (CACM vol 10, Oct 1967), as in musl's csqrt.
///
/// Inputs are first rescaled by a power of two:
/// - Components at or above `threshold` are scaled down by 4 so that
///   `hypot(x, y)` and `x + hypot(x, y)` cannot overflow.
/// - Inputs whose components are both subnormal are scaled up by 2^54 so that
///   the halving step neither loses precision nor rounds to zero.
///
/// The result is then multiplied by the square root of the inverse scale.
fn sqrt64(z: Complex(f64)) Complex(f64) {
    // may encounter overflow for im,re >= DBL_MAX / (1 + sqrt(2))
    const threshold = 0x1.a827999fcef32p+1022;

    var x = z.re;
    var y = z.im;

    if (x == 0 and y == 0) {
        return Complex(f64).init(0, y);
    }
    if (math.isInf(y)) {
        return Complex(f64).init(math.inf(f64), y);
    }
    if (math.isNan(x)) {
        // raise invalid if y is not nan
        const t = (y - y) / (y - y);
        return Complex(f64).init(x, t);
    }
    if (math.isInf(x)) {
        // sqrt(inf + i nan) = inf + nan i
        // sqrt(inf + iy) = inf + i0
        // sqrt(-inf + i nan) = nan +- inf i
        // sqrt(-inf + iy) = 0 + inf i
        if (math.signbit(x)) {
            return Complex(f64).init(@abs(y - y), math.copysign(x, y));
        } else {
            return Complex(f64).init(x, math.copysign(y - y, y));
        }
    }

    // y = nan special case is handled fine below

    // All scale factors are powers of two. The multiplications are therefore
    // exact whenever the values stay in the normal range. `factor` is the
    // square root of the inverse input scale and undoes it on the result.
    var factor: f64 = 1;
    if (@abs(x) >= threshold or @abs(y) >= threshold) {
        // sqrt(z) = 2 * sqrt(z / 4)
        x *= 0.25;
        y *= 0.25;
        factor = 2;
    } else if (@abs(x) < math.floatMin(f64) and @abs(y) < math.floatMin(f64)) {
        // Both components subnormal. Halving x + hypot(x, y) would lose bits.
        // For z = 0 + 2^-1074 i it would round to zero, giving t = 0 and an
        // infinite imaginary part. sqrt(z) = 2^-27 * sqrt(z * 2^54).
        x *= 0x1p54;
        y *= 0x1p54;
        factor = 0x1p-27;
    }

    if (x >= 0) {
        const t = @sqrt((x + math.hypot(x, y)) * 0.5);
        return Complex(f64).init(factor * t, factor * (y / (2.0 * t)));
    } else {
        const t = @sqrt((-x + math.hypot(x, y)) * 0.5);
        return Complex(f64).init(
            factor * (@abs(y) / (2.0 * t)),
            factor * math.copysign(t, y),
        );
    }
}

test "sqrt64 with subnormal components" {
    // Without subnormal rescaling, 0 + 2^-1074 i produced t = 0 and an infinite
    // imaginary part. The exact root is 2^-537.5 * (1 + i).
    const c = sqrt64(Complex(f64).init(0, 0x1p-1074));
    const expected: f64 = math.sqrt2 * 0x1p-538;
    try testing.expect(math.isFinite(c.re) and math.isFinite(c.im));
    try testing.expectApproxEqRel(expected, c.re, 4 * math.floatEps(f64));
    try testing.expectApproxEqRel(expected, c.im, 4 * math.floatEps(f64));
}
129
test sqrt32 {
    const epsilon = math.floatEps(f32);
    const a = Complex(f32).init(5, 3);
    const c = sqrt(a);

    try testing.expectApproxEqAbs(2.3271174, c.re, epsilon);
    try testing.expectApproxEqAbs(0.6445742, c.im, epsilon);

    // Principal branch: the real part is non-negative and the imaginary part
    // follows the sign of z.im, in every quadrant and on the negative real axis.
    // Each row is { z.re, z.im, want.re, want.im }.
    const cases = [_][4]f32{
        .{ 3, -4, 2, -1 },
        .{ -3, 4, 1, 2 },
        .{ -3, -4, 1, -2 },
        .{ -4, 0, 0, 2 },
    };
    for (cases) |row| {
        const r = sqrt(Complex(f32).init(row[0], row[1]));
        try testing.expectApproxEqAbs(row[2], r.re, epsilon);
        try testing.expectApproxEqAbs(row[3], r.im, epsilon);
    }

    // Special values handled before the general formula.
    const neg_inf = sqrt(Complex(f32).init(-math.inf(f32), 1));
    try testing.expectEqual(@as(f32, 0), neg_inf.re);
    try testing.expect(math.isPositiveInf(neg_inf.im));

    const inf_im = sqrt(Complex(f32).init(1, math.inf(f32)));
    try testing.expect(math.isPositiveInf(inf_im.re));
    try testing.expect(math.isPositiveInf(inf_im.im));

    const nan_re = sqrt(Complex(f32).init(math.nan(f32), 1));
    try testing.expect(math.isNan(nan_re.re) and math.isNan(nan_re.im));
}
138
test sqrt64 {
    const epsilon = math.floatEps(f64);
    const a = Complex(f64).init(5, 3);
    const c = sqrt(a);

    try testing.expectApproxEqAbs(2.3271175190399496, c.re, epsilon);
    try testing.expectApproxEqAbs(0.6445742373246469, c.im, epsilon);

    // Principal branch: the real part is non-negative and the imaginary part
    // follows the sign of z.im, in every quadrant and on the negative real axis.
    // Each row is { z.re, z.im, want.re, want.im }.
    const cases = [_][4]f64{
        .{ 3, -4, 2, -1 },
        .{ -3, 4, 1, 2 },
        .{ -3, -4, 1, -2 },
        .{ -4, 0, 0, 2 },
    };
    for (cases) |row| {
        const r = sqrt(Complex(f64).init(row[0], row[1]));
        try testing.expectApproxEqAbs(row[2], r.re, epsilon);
        try testing.expectApproxEqAbs(row[3], r.im, epsilon);
    }

    // Components above the overflow threshold are scaled internally:
    // sqrt(2^1023) = 2^511 * sqrt(2).
    const big = sqrt(Complex(f64).init(0x1p1023, 0));
    const big_expected: f64 = math.sqrt2 * 0x1p511;
    try testing.expectApproxEqRel(big_expected, big.re, 4 * epsilon);
    try testing.expectEqual(@as(f64, 0), big.im);

    // Special values handled before the general formula.
    const neg_inf = sqrt(Complex(f64).init(-math.inf(f64), 1));
    try testing.expectEqual(@as(f64, 0), neg_inf.re);
    try testing.expect(math.isPositiveInf(neg_inf.im));

    const nan_re = sqrt(Complex(f64).init(math.nan(f64), 1));
    try testing.expect(math.isNan(nan_re.re) and math.isNan(nan_re.im));
}