master
1const std = @import("../std.zig");
2const math = std.math;
3const expect = std.testing.expect;
4const TypeId = std.builtin.TypeId;
5const maxInt = std.math.maxInt;
6
/// Returns the square root of `x`.
///
/// Float operands (runtime or comptime) are forwarded to the `@sqrt`
/// builtin. Special cases for floats:
/// - sqrt(+inf) = +inf
/// - sqrt(+-0) = +-0
/// - sqrt(x) = nan if x < 0
/// - sqrt(nan) = nan
///
/// Unsigned integer operands yield the floor of the square root, typed as
/// `Sqrt(@TypeOf(x))`. A non-negative `comptime_int` no larger than
/// `maxInt(u128)` yields a `comptime_int`. Signed integers, negative or
/// larger `comptime_int` values, and every other type are compile errors.
/// TODO Decide if all this logic should be implemented directly in the @sqrt builtin function.
pub fn sqrt(x: anytype) Sqrt(@TypeOf(x)) {
    const T = @TypeOf(x);
    return switch (@typeInfo(T)) {
        .float, .comptime_float => @sqrt(x),
        // comptime_int values are routed through the u128 integer routine,
        // so they must be non-negative and fit in 128 bits.
        .comptime_int => comptime blk: {
            if (x > maxInt(u128)) {
                @compileError("sqrt not implemented for comptime_int greater than 128 bits");
            }
            if (x < 0) {
                @compileError("sqrt on negative number");
            }
            break :blk @as(T, sqrt_int(u128, x));
        },
        .int => |info| if (info.signedness == .unsigned)
            sqrt_int(T, x)
        else
            @compileError("sqrt not implemented for signed integers"),
        else => @compileError("sqrt not implemented for " ++ @typeName(T)),
    };
}
35
/// Computes the floor of the square root of the unsigned integer `value`.
/// The result has type `Sqrt(T)`: unsigned, with `(bits + 1) / 2` bits.
///
/// Uses the bit-by-bit method. `bit` steps down through powers of four,
/// `root` accumulates the answer, and `remainder` holds what is left of
/// `value` after each accepted step.
fn sqrt_int(comptime T: type, value: T) Sqrt(T) {
    const bits = @typeInfo(T).int.bits;
    switch (bits) {
        // Types of at most two bits are answered directly. This keeps u0, u1
        // and u2 out of the general path below.
        0...2 => return if (value == 0) 0 else 1,
        else => {
            // Exponent of the largest power of four representable in T:
            // (bits - 1) rounded down to an even number.
            const top_shift = ((bits - 1) / 2) * 2;
            var remainder = value;
            var root: T = 0;
            var bit: T = 1 << top_shift;

            // Begin at the largest power of four that does not exceed `value`.
            while (bit > remainder) bit >>= 2;

            while (bit != 0) : (bit >>= 2) {
                const trial = root + bit;
                if (remainder >= trial) {
                    remainder -= trial;
                    root = (root >> 1) + bit;
                } else {
                    root >>= 1;
                }
            }

            return @intCast(root);
        },
    }
}
63
test sqrt_int {
    // Floor square roots on either side of the perfect squares 4 and 9.
    const u32_cases = [_]struct { value: u32, root: u16 }{
        .{ .value = 3, .root = 1 },
        .{ .value = 4, .root = 2 },
        .{ .value = 5, .root = 2 },
        .{ .value = 8, .root = 2 },
        .{ .value = 9, .root = 3 },
        .{ .value = 10, .root = 3 },
    };
    for (u32_cases) |case| {
        try expect(sqrt_int(u32, case.value) == case.root);
    }

    // Tiny widths: u0..u2 take the small-type shortcut, while u3 and u4 are
    // the narrowest types that run the general algorithm.
    try expect(sqrt_int(u0, 0) == 0);
    try expect(sqrt_int(u1, 1) == 1);
    try expect(sqrt_int(u2, 3) == 1);
    try expect(sqrt_int(u3, 4) == 2);
    try expect(sqrt_int(u4, 8) == 2);
    try expect(sqrt_int(u4, 9) == 3);
}
79
/// Returns the type that `sqrt` produces for an operand of type `T`.
///
/// An integer type (of either signedness) maps to an unsigned integer with
/// `(bits + 1) / 2` bits, which is wide enough for the floor of its square
/// root. Every other type maps to itself.
pub fn Sqrt(comptime T: type) type {
    const info = @typeInfo(T);
    if (info != .int) return T;
    return @Int(.unsigned, (info.int.bits + 1) / 2);
}