Commit 05380a11a4
2021-03-24 20:59:15
1 parent
4a3ac16Changed files (1)
lib
std
math
lib/std/math/sqrt.zig
@@ -39,55 +39,50 @@ pub fn sqrt(x: anytype) Sqrt(@TypeOf(x)) {
}
fn sqrt_int(comptime T: type, value: T) Sqrt(T) {
- switch (T) {
- u0 => return 0,
- u1 => return value,
- else => {},
- }
-
- var op = value;
- var res: T = 0;
- var one: T = 1 << (@typeInfo(T).Int.bits - 2);
+ if (@typeInfo(T).Int.bits <= 2) {
+ return if (value == 0) 0 else 1; // shortcut for small number of bits to simplify general case
+ } else {
+ var op = value;
+ var res: T = 0;
+ var one: T = 1 << ((@typeInfo(T).Int.bits - 1) & -2); // highest power of four that fits into T
- // "one" starts at the highest power of four <= than the argument.
- while (one > op) {
- one >>= 2;
- }
+ // "one" starts at the highest power of four <= than the argument.
+ while (one > op) {
+ one >>= 2;
+ }
- while (one != 0) {
- if (op >= res + one) {
- op -= res + one;
- res += 2 * one;
+ while (one != 0) {
+ var c = op >= res + one;
+ if (c) op -= res + one;
+ res >>= 1;
+ if (c) res += one;
+ one >>= 2;
}
- res >>= 1;
- one >>= 2;
- }
- const ResultType = Sqrt(T);
- return @intCast(ResultType, res);
+ return @intCast(Sqrt(T), res);
+ }
}
test "math.sqrt_int" {
- try expect(sqrt_int(u0, 0) == 0);
- try expect(sqrt_int(u1, 1) == 1);
try expect(sqrt_int(u32, 3) == 1);
try expect(sqrt_int(u32, 4) == 2);
try expect(sqrt_int(u32, 5) == 2);
try expect(sqrt_int(u32, 8) == 2);
try expect(sqrt_int(u32, 9) == 3);
try expect(sqrt_int(u32, 10) == 3);
+
+ try expect(sqrt_int(u0, 0) == 0);
+ try expect(sqrt_int(u1, 1) == 1);
+ try expect(sqrt_int(u2, 3) == 1);
+ try expect(sqrt_int(u3, 4) == 2);
+ try expect(sqrt_int(u4, 8) == 2);
+ try expect(sqrt_int(u4, 9) == 3);
}
/// Returns the return type `sqrt` will return given an operand of type `T`.
pub fn Sqrt(comptime T: type) type {
return switch (@typeInfo(T)) {
- .Int => |int| {
- return switch (int.bits) {
- 0 => u0,
- 1 => u1,
- else => std.meta.Int(.unsigned, int.bits / 2),
- };
- },
+ .Int => |int| std.meta.Int(.unsigned, (int.bits + 1) / 2),
else => T,
};
}