Commit 05380a11a4

Filippo Casarin <casarin.filippo17@gmail.com>
2021-03-24 20:59:15
std.math.sqrt_int: fixed odd size integers types
1 parent 4a3ac16
Changed files (1)
lib
std
lib/std/math/sqrt.zig
@@ -39,55 +39,50 @@ pub fn sqrt(x: anytype) Sqrt(@TypeOf(x)) {
 }
 
 fn sqrt_int(comptime T: type, value: T) Sqrt(T) {
-    switch (T) {
-        u0 => return 0,
-        u1 => return value,
-        else => {},
-    }
-
-    var op = value;
-    var res: T = 0;
-    var one: T = 1 << (@typeInfo(T).Int.bits - 2);
+    if (@typeInfo(T).Int.bits <= 2) {
+        return if (value == 0) 0 else 1; // shortcut for small number of bits to simplify general case
+    } else {
+        var op = value;
+        var res: T = 0;
+        var one: T = 1 << ((@typeInfo(T).Int.bits - 1) & -2); // highest power of four that fits into T
 
-    // "one" starts at the highest power of four <= than the argument.
-    while (one > op) {
-        one >>= 2;
-    }
+        // "one" starts at the highest power of four <= than the argument.
+        while (one > op) {
+            one >>= 2;
+        }
 
-    while (one != 0) {
-        if (op >= res + one) {
-            op -= res + one;
-            res += 2 * one;
+        while (one != 0) {
+            var c = op >= res + one;
+            if (c) op -= res + one;
+            res >>= 1;
+            if (c) res += one;
+            one >>= 2;
         }
-        res >>= 1;
-        one >>= 2;
-    }
 
-    const ResultType = Sqrt(T);
-    return @intCast(ResultType, res);
+        return @intCast(Sqrt(T), res);
+    }
 }
 
 test "math.sqrt_int" {
-    try expect(sqrt_int(u0, 0) == 0);
-    try expect(sqrt_int(u1, 1) == 1);
     try expect(sqrt_int(u32, 3) == 1);
     try expect(sqrt_int(u32, 4) == 2);
     try expect(sqrt_int(u32, 5) == 2);
     try expect(sqrt_int(u32, 8) == 2);
     try expect(sqrt_int(u32, 9) == 3);
     try expect(sqrt_int(u32, 10) == 3);
+
+    try expect(sqrt_int(u0, 0) == 0);
+    try expect(sqrt_int(u1, 1) == 1);
+    try expect(sqrt_int(u2, 3) == 1);
+    try expect(sqrt_int(u3, 4) == 2);
+    try expect(sqrt_int(u4, 8) == 2);
+    try expect(sqrt_int(u4, 9) == 3);
 }
 
 /// Returns the return type `sqrt` will return given an operand of type `T`.
 pub fn Sqrt(comptime T: type) type {
     return switch (@typeInfo(T)) {
-        .Int => |int| {
-            return switch (int.bits) {
-                0 => u0,
-                1 => u1,
-                else => std.meta.Int(.unsigned, int.bits / 2),
-            };
-        },
+        .Int => |int| std.meta.Int(.unsigned, (int.bits + 1) / 2),
         else => T,
     };
 }