Commit 844e05f619

Andrew Kelley <superjoe30@gmail.com>
2017-10-01 03:09:58
improve bit shift API in std.math
* `shl` moved to `shlExact` * added `shl` and `shr` which are truncating like `<<` and `>>`. closes #403
1 parent c6295fe
Changed files (1)
std
std/math/index.zig
@@ -219,11 +219,59 @@ pub fn negate(x: var) -> %@typeOf(x) {
 }
 
 error Overflow;
-pub fn shl(comptime T: type, a: T, shift_amt: Log2Int(T)) -> %T {
+pub fn shlExact(comptime T: type, a: T, shift_amt: Log2Int(T)) -> %T {
     var answer: T = undefined;
     if (@shlWithOverflow(T, a, shift_amt, &answer)) error.Overflow else answer
 }
 
+/// Shifts left. Overflowed bits are truncated.
+/// A negative shift amount results in a right shift.
+pub fn shl(comptime T: type, a: T, shift_amt: var) -> T {
+    const abs_shift_amt = absCast(shift_amt);
+    const casted_shift_amt = if (abs_shift_amt >= T.bit_count) return 0 else Log2Int(T)(abs_shift_amt);
+
+    if (@typeOf(shift_amt).is_signed) {
+        if (shift_amt >= 0) {
+            return a << casted_shift_amt;
+        } else {
+            return a >> casted_shift_amt;
+        }
+    }
+
+    return a << casted_shift_amt;
+}
+
+test "math.shl" {
+    assert(shl(u8, 0b11111111, usize(3)) == 0b11111000);
+    assert(shl(u8, 0b11111111, usize(8)) == 0);
+    assert(shl(u8, 0b11111111, usize(9)) == 0);
+    assert(shl(u8, 0b11111111, isize(-2)) == 0b00111111);
+}
+
+/// Shifts right. Overflowed bits are truncated.
+/// A negative shift amount results in a lefft shift.
+pub fn shr(comptime T: type, a: T, shift_amt: var) -> T {
+    const abs_shift_amt = absCast(shift_amt);
+    const casted_shift_amt = if (abs_shift_amt >= T.bit_count) return 0 else Log2Int(T)(abs_shift_amt);
+
+    if (@typeOf(shift_amt).is_signed) {
+        if (shift_amt >= 0) {
+            return a >> casted_shift_amt;
+        } else {
+            return a << casted_shift_amt;
+        }
+    }
+
+    return a >> casted_shift_amt;
+}
+
+test "math.shr" {
+    assert(shr(u8, 0b11111111, usize(3)) == 0b00011111);
+    assert(shr(u8, 0b11111111, usize(8)) == 0);
+    assert(shr(u8, 0b11111111, usize(9)) == 0);
+    assert(shr(u8, 0b11111111, isize(-2)) == 0b11111100);
+}
+
 pub fn Log2Int(comptime T: type) -> type {
     @IntType(false, log2(T.bit_count))
 }
@@ -237,7 +285,7 @@ fn testOverflow() {
     assert(%%mul(i32, 3, 4) == 12);
     assert(%%add(i32, 3, 4) == 7);
     assert(%%sub(i32, 3, 4) == -1);
-    assert(%%shl(i32, 0b11, 4) == 0b110000);
+    assert(%%shlExact(i32, 0b11, 4) == 0b110000);
 }