master
  1// Ported from musl, which is licensed under the MIT license:
  2// https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
  3//
  4// https://git.musl-libc.org/cgit/musl/tree/src/math/tgamma.c
  5
  6const builtin = @import("builtin");
  7const std = @import("../std.zig");
  8
  9/// Returns the gamma function of x,
 10/// gamma(x) = factorial(x - 1) for integer x.
 11///
 12/// Special Cases:
 13///  - gamma(+-nan) = nan
 14///  - gamma(-inf)  = nan
 15///  - gamma(n)     = nan for negative integers
 16///  - gamma(-0.0)  = -inf
 17///  - gamma(+0.0)  = +inf
 18///  - gamma(+inf)  = +inf
 19pub fn gamma(comptime T: type, x: T) T {
 20    if (T != f32 and T != f64) {
 21        @compileError("gamma not implemented for " ++ @typeName(T));
 22    }
 23    // common integer case first
 24    if (x == @trunc(x)) {
 25        // gamma(-inf) = nan
 26        // gamma(n)    = nan for negative integers
 27        if (x < 0) {
 28            return std.math.nan(T);
 29        }
 30        // gamma(-0.0) = -inf
 31        // gamma(+0.0) = +inf
 32        if (x == 0) {
 33            return 1 / x;
 34        }
 35        if (x < integer_result_table.len) {
 36            const i = @as(u8, @intFromFloat(x));
 37            return @floatCast(integer_result_table[i]);
 38        }
 39    }
 40    // below this, result underflows, but has a sign
 41    // negative for (-1,  0)
 42    // positive for (-2, -1)
 43    // negative for (-3, -2)
 44    // ...
 45    const lower_bound = if (T == f64) -184 else -42;
 46    if (x < lower_bound) {
 47        return if (@mod(x, 2) > 1) -0.0 else 0.0;
 48    }
 49    // above this, result overflows
 50    // gamma(+inf) = +inf
 51    const upper_bound = if (T == f64) 172 else 36;
 52    if (x > upper_bound) {
 53        return std.math.inf(T);
 54    }
 55
 56    const abs = @abs(x);
 57    // perfect precision here
 58    if (abs < 0x1p-54) {
 59        return 1 / x;
 60    }
 61
 62    const base = abs + lanczos_minus_half;
 63    const exponent = abs - 0.5;
 64    // error of y for correction, see
 65    // https://github.com/python/cpython/blob/5dc79e3d7f26a6a871a89ce3efc9f1bcee7bb447/Modules/mathmodule.c#L286-L324
 66    const e = if (abs > lanczos_minus_half)
 67        base - abs - lanczos_minus_half
 68    else
 69        base - lanczos_minus_half - abs;
 70    const correction = lanczos * e / base;
 71    const initial = series(T, abs) * @exp(-base);
 72
 73    // use reflection formula for negatives
 74    if (x < 0) {
 75        const reflected = -std.math.pi / (abs * sinpi(T, abs) * initial);
 76        const corrected = reflected - reflected * correction;
 77        const half_pow = std.math.pow(T, base, -0.5 * exponent);
 78        return corrected * half_pow * half_pow;
 79    } else {
 80        const corrected = initial + initial * correction;
 81        const half_pow = std.math.pow(T, base, 0.5 * exponent);
 82        return corrected * half_pow * half_pow;
 83    }
 84}
 85
 86/// Returns the natural logarithm of the absolute value of the gamma function.
 87///
 88/// Special Cases:
 89///  - lgamma(+-nan) = nan
 90///  - lgamma(+-inf) = +inf
 91///  - lgamma(n)     = +inf for negative integers
 92///  - lgamma(+-0.0) = +inf
 93///  - lgamma(1)     = +0.0
 94///  - lgamma(2)     = +0.0
 95pub fn lgamma(comptime T: type, x: T) T {
 96    if (T != f32 and T != f64) {
 97        @compileError("gamma not implemented for " ++ @typeName(T));
 98    }
 99    // common integer case first
100    if (x == @trunc(x)) {
101        // lgamma(-inf)  = +inf
102        // lgamma(n)     = +inf for negative integers
103        // lgamma(+-0.0) = +inf
104        if (x <= 0) {
105            return std.math.inf(T);
106        }
107        // lgamma(1) = +0.0
108        // lgamma(2) = +0.0
109        if (x < integer_result_table.len) {
110            const i = @as(u8, @intFromFloat(x));
111            return @log(@as(T, @floatCast(integer_result_table[i])));
112        }
113        // lgamma(+inf) = +inf
114        if (std.math.isPositiveInf(x)) {
115            return x;
116        }
117    }
118
119    const abs = @abs(x);
120    // perfect precision here
121    if (abs < 0x1p-54) {
122        return -@log(abs);
123    }
124    // obvious approach when overflow is not a problem
125    const upper_bound = if (T == f64) 128 else 26;
126    if (abs < upper_bound) {
127        return @log(@abs(gamma(T, x)));
128    }
129
130    const log_base = @log(abs + lanczos_minus_half) - 1;
131    const exponent = abs - 0.5;
132    const log_series = @log(series(T, abs));
133    const initial = exponent * log_base + log_series - lanczos;
134
135    // use reflection formula for negatives
136    if (x < 0) {
137        const reflected = std.math.pi / (abs * sinpi(T, abs));
138        return @log(@abs(reflected)) - initial;
139    }
140    return initial;
141}
142
// Lookup table for the integer early return: entry n holds gamma(n) = (n - 1)!.
// Entry 0 stands for gamma(+0.0); the final entry is gamma(22) = 21!.
// NOTE(review): the previous comment justified the cut-off by f64 precision of
// the next value; 22! (gamma(23)) appears to still be exactly representable in
// f64, so confirm that reasoning before extending or shrinking the table.
const integer_result_table = [_]f64{
    std.math.inf(f64), // gamma(+0.0)
    1, // gamma(1)  = 0!
    1, // gamma(2)  = 1!
    2, // gamma(3)  = 2!
    6, // gamma(4)  = 3!
    24, // gamma(5)  = 4!
    120, // gamma(6)  = 5!
    720, // gamma(7)  = 6!
    5_040, // gamma(8)  = 7!
    40_320, // gamma(9)  = 8!
    362_880, // gamma(10) = 9!
    3_628_800, // gamma(11) = 10!
    39_916_800, // gamma(12) = 11!
    479_001_600, // gamma(13) = 12!
    6_227_020_800, // gamma(14) = 13!
    87_178_291_200, // gamma(15) = 14!
    1_307_674_368_000, // gamma(16) = 15!
    20_922_789_888_000, // gamma(17) = 16!
    355_687_428_096_000, // gamma(18) = 17!
    6_402_373_705_728_000, // gamma(19) = 18!
    121_645_100_408_832_000, // gamma(20) = 19!
    2_432_902_008_176_640_000, // gamma(21) = 20!
    51_090_942_171_709_440_000, // gamma(22) = 21!
};
170
// The "g" constant shared by `gamma` and `lgamma`; its value is an arbitrary choice.
const lanczos: comptime_float = 6.024680040776729583740234375;
// g - 0.5, the offset added to |x| to form the base of the power term.
const lanczos_minus_half: comptime_float = lanczos - 0.5;
174
// Evaluates the rational function P(abs) / Q(abs), where entry k of each
// coefficient array below multiplies abs^k.
// Two evaluation orders are used so that intermediate values stay in range.
fn series(comptime T: type, abs: T) T {
    const p_coeffs = [_]T{
        23531376880.410759688572007674451636754734846804940,
        42919803642.649098768957899047001988850926355848959,
        35711959237.355668049440185451547166705960488635843,
        17921034426.037209699919755754458931112671403265390,
        6039542586.3520280050642916443072979210699388420708,
        1439720407.3117216736632230727949123939715485786772,
        248874557.86205415651146038641322942321632125127801,
        31426415.585400194380614231628318205362874684987640,
        2876370.6289353724412254090516208496135991145378768,
        186056.26539522349504029498971604569928220784236328,
        8071.6720023658162106380029022722506138218516325024,
        210.82427775157934587250973392071336271166969580291,
        2.5066282746310002701649081771338373386264310793408,
    };
    const q_coeffs = [_]T{
        0.0,
        39916800.0,
        120543840.0,
        150917976.0,
        105258076.0,
        45995730.0,
        13339535.0,
        2637558.0,
        357423.0,
        32670.0,
        1925.0,
        66.0,
        1.0,
    };

    var top: T = 0;
    var bottom: T = 0;
    if (abs < 8) {
        // Horner's scheme from the highest power down; only used for small
        // abs because repeated multiplication by a large abs would overflow.
        var k: usize = p_coeffs.len;
        while (k > 0) {
            k -= 1;
            top = top * abs + p_coeffs[k];
            bottom = bottom * abs + q_coeffs[k];
        }
    } else {
        // Same polynomials in powers of 1 / abs (both scaled by the same
        // factor, which cancels in the ratio); only used for large abs
        // because repeated division by a small abs would overflow.
        for (p_coeffs, q_coeffs) |p, q| {
            top = top / abs + p;
            bottom = bottom / abs + q;
        }
    }
    return top / bottom;
}
224
// Computes sin(pi * x) by first reducing the argument so that the angle
// handed to @sin / @cos stays within [-pi/4, pi/4].
// Callers deal with integer x and |x| < 2^-54 themselves before calling this.
fn sinpi(comptime T: type, x: T) T {
    // Reduce to one period: [0, 2]
    const reduced = @mod(x, 2);
    // Index of the nearest multiple of 1/2: {0, 1, 2, 3, 4}
    const quarter_steps: u8 = @intFromFloat(4 * reduced);
    const half_index = (quarter_steps + 1) / 2;
    // Distance from that multiple: [-0.25, 0.25]
    const offset = reduced - 0.5 * @as(T, @floatFromInt(half_index));
    const angle = std.math.pi * offset;
    return switch (half_index) {
        0, 4 => @sin(angle),
        1 => @cos(angle),
        2 => -@sin(angle),
        3 => -@cos(angle),
        else => unreachable,
    };
}
239
240const expect = std.testing.expect;
241const expectEqual = std.testing.expectEqual;
242const expectApproxEqRel = std.testing.expectApproxEqRel;
243
test gamma {
    inline for (&.{ f32, f64 }) |T| {
        const tolerance = @sqrt(std.math.floatEps(T));
        // Each row is { x, expected gamma(x) }.
        const cases = [_][2]T{
            // positive integers
            .{ 6, 120.0 },
            .{ 10, 362880.0 },
            .{ 19, 6402373705728000.0 },
            // arguments in (0, 1)
            .{ 0.003, 332.7590766955334570 },
            .{ 0.654, 1.377260301981044573 },
            .{ 0.959, 1.025393882573518478 },
            // positive non-integers
            .{ 4.16, 7.361898021467681690 },
            .{ 9.73, 198337.2940287730753 },
            .{ 17.6, 113718145797241.1666 },
            // negative non-integers
            .{ -2.80, -1.13860211111081424930673 },
            .{ -7.74, 0.00018573407931875070158 },
            .{ -12.1, -0.00000001647990903942825 },
        };
        for (cases) |case| {
            try expectApproxEqRel(case[1], gamma(T, case[0]), tolerance);
        }
    }
}
264
test "gamma.special" {
    if (builtin.cpu.arch.isArm() and builtin.target.abi.float() == .soft) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21234

    inline for (&.{ f32, f64 }) |T| {
        const pos_inf = std.math.inf(T);
        const quiet_nan = std.math.nan(T);

        // NaN, -inf and negative integers all produce NaN.
        const nan_inputs = [_]T{ -quiet_nan, quiet_nan, -pos_inf, -4, -11, -78 };
        for (nan_inputs) |input| {
            try expect(std.math.isNan(gamma(T, input)));
        }

        // The sign of zero selects the sign of the infinity.
        try expectEqual(-pos_inf, gamma(T, -0.0));
        try expectEqual(pos_inf, gamma(T, 0.0));

        // Underflow keeps the alternating sign of the true result.
        try expect(std.math.isNegativeZero(gamma(T, -200.5)));
        try expect(std.math.isPositiveZero(gamma(T, -201.5)));
        try expect(std.math.isNegativeZero(gamma(T, -202.5)));

        // Large arguments, +inf included, overflow to +inf.
        const overflow_inputs = [_]T{ 200, 201, 202, pos_inf };
        for (overflow_inputs) |input| {
            try expectEqual(pos_inf, gamma(T, input));
        }
    }
}
291
test lgamma {
    inline for (&.{ f32, f64 }) |T| {
        const tolerance = @sqrt(std.math.floatEps(T));
        // Each row is { x, expected lgamma(x) }.
        const cases = [_][2]T{
            // positive integers
            .{ 5, @log(24.0) },
            .{ 17, @log(20922789888000.0) },
            .{ 21, @log(2432902008176640000.0) },
            // arguments in (0, 1)
            .{ 0.105, 2.201821590438859327 },
            .{ 0.253, 1.275416975248413231 },
            .{ 0.823, 0.130463884049976732 },
            // positive non-integers
            .{ 21.3, 43.24395772148497989 },
            .{ 41.1, 110.6908958012102623 },
            .{ 67.4, 215.2123266224689711 },
            // negative non-integers
            .{ -43.6, -122.605958469563489 },
            .{ -81.4, -278.633885462703133 },
            .{ -93.6, -333.247676253238363 },
        };
        for (cases) |case| {
            try expectApproxEqRel(case[1], lgamma(T, case[0]), tolerance);
        }
    }
}
312
test "lgamma.special" {
    inline for (&.{ f32, f64 }) |T| {
        const pos_inf = std.math.inf(T);
        const quiet_nan = std.math.nan(T);

        // NaN of either sign propagates.
        for ([_]T{ -quiet_nan, quiet_nan }) |input| {
            try expect(std.math.isNan(lgamma(T, input)));
        }

        // Infinities, negative integers and zeros all map to +inf.
        const infinite_inputs = [_]T{ -pos_inf, pos_inf, -5, -8, -15, -0.0, 0.0 };
        for (infinite_inputs) |input| {
            try expectEqual(pos_inf, lgamma(T, input));
        }

        // gamma(1) = gamma(2) = 1, so the logarithm is exactly +0.0.
        for ([_]T{ 1, 2 }) |input| {
            try expect(std.math.isPositiveZero(lgamma(T, input)));
        }
    }
}