master
 1const std = @import("../std.zig");
 2const math = std.math;
 3const expect = std.testing.expect;
 4const TypeId = std.builtin.TypeId;
 5const maxInt = std.math.maxInt;
 6
 7/// Returns the square root of x.
 8///
 9/// Special Cases:
10///  - sqrt(+inf)  = +inf
11///  - sqrt(+-0)   = +-0
12///  - sqrt(x)     = nan if x < 0
13///  - sqrt(nan)   = nan
14/// TODO Decide if all this logic should be implemented directly in the @sqrt builtin function.
15pub fn sqrt(x: anytype) Sqrt(@TypeOf(x)) {
16    const T = @TypeOf(x);
17    switch (@typeInfo(T)) {
18        .float, .comptime_float => return @sqrt(x),
19        .comptime_int => comptime {
20            if (x > maxInt(u128)) {
21                @compileError("sqrt not implemented for comptime_int greater than 128 bits");
22            }
23            if (x < 0) {
24                @compileError("sqrt on negative number");
25            }
26            return @as(T, sqrt_int(u128, x));
27        },
28        .int => |IntType| switch (IntType.signedness) {
29            .signed => @compileError("sqrt not implemented for signed integers"),
30            .unsigned => return sqrt_int(T, x),
31        },
32        else => @compileError("sqrt not implemented for " ++ @typeName(T)),
33    }
34}
35
/// Computes floor(sqrt(value)) for the integer type `T` with the binary
/// digit-by-digit method, returning the result as `Sqrt(T)`.
/// NOTE(review): assumes `T` is unsigned; `sqrt` only forwards unsigned
/// integer types (or u128 for comptime_int) here.
fn sqrt_int(comptime T: type, value: T) Sqrt(T) {
    const bit_count = @typeInfo(T).int.bits;

    // Widths of at most two bits hold only 0..3, whose floor square roots are 0 or 1.
    // Handling them up front also keeps the code below valid: it shifts by 2, which
    // is not a legal shift amount for such narrow types.
    if (bit_count <= 2) return if (value == 0) 0 else 1;

    // Largest even exponent not exceeding bit_count - 1, so `bit` starts at the
    // highest power of four representable in T.
    const top_shift = ((bit_count - 1) / 2) * 2;
    var bit: T = 1 << top_shift;

    // Drop powers of four that exceed the operand.
    while (bit > value) bit >>= 2;

    var remainder = value;
    var root: T = 0;
    while (bit != 0) : (bit >>= 2) {
        // Try to set the current result bit; keep it only if the remainder allows.
        const trial = root + bit;
        root >>= 1;
        if (remainder >= trial) {
            remainder -= trial;
            root += bit;
        }
    }

    return @intCast(root);
}
63
test sqrt_int {
    // Small operands on either side of the perfect squares 4 and 9; the result is floor(sqrt(value)).
    try expect(sqrt_int(u32, 0) == 0);
    try expect(sqrt_int(u32, 3) == 1);
    try expect(sqrt_int(u32, 4) == 2);
    try expect(sqrt_int(u32, 5) == 2);
    try expect(sqrt_int(u32, 8) == 2);
    try expect(sqrt_int(u32, 9) == 3);
    try expect(sqrt_int(u32, 10) == 3);

    // Very narrow types, including the `bits <= 2` shortcut path.
    try expect(sqrt_int(u0, 0) == 0);
    try expect(sqrt_int(u1, 1) == 1);
    try expect(sqrt_int(u2, 3) == 1);
    try expect(sqrt_int(u3, 4) == 2);
    try expect(sqrt_int(u4, 8) == 2);
    try expect(sqrt_int(u4, 9) == 3);

    // The largest operand of each width must map to the largest value of the
    // half-width result type; this also exercises the widest intermediate sums.
    try expect(sqrt_int(u8, maxInt(u8)) == maxInt(u4));
    try expect(sqrt_int(u32, maxInt(u32)) == maxInt(u16));
    try expect(sqrt_int(u64, maxInt(u64)) == maxInt(u32));
    try expect(sqrt_int(u128, maxInt(u128)) == maxInt(u64));

    // Exhaustive check of every operand for small odd and even widths:
    // r == floor(sqrt(v)) exactly when r*r <= v < (r+1)*(r+1).
    inline for (.{ u1, u2, u3, u4, u5, u6, u7, u8 }) |UInt| {
        for (0..maxInt(UInt) + 1) |v| {
            const r: usize = sqrt_int(UInt, @intCast(v));
            try expect(r * r <= v);
            try expect((r + 1) * (r + 1) > v);
        }
    }
}
79
/// Result type of `sqrt` for an operand of type `T`.
///
/// Any integer type maps to an unsigned integer with `(bits + 1) / 2` bits,
/// i.e. half the operand width rounded up. Every other type maps to itself.
pub fn Sqrt(comptime T: type) type {
    const info = @typeInfo(T);
    if (info != .int) return T;
    return @Int(.unsigned, (info.int.bits + 1) / 2);
}