Commit 21213127ec

Martin Wickham <spexguy070@gmail.com>
2020-12-09 08:17:55
Modify cityhash to work at comptime
1 parent 73b1747
Changed files (1)
lib
std
lib/std/hash/cityhash.zig
@@ -6,6 +6,19 @@
 const std = @import("std");
 const builtin = @import("builtin");
 
+inline fn offsetPtr(ptr: [*]const u8, offset: usize) [*]const u8 {
+    // ptr + offset doesn't work at comptime so we need this instead.
+    return @ptrCast([*]const u8, &ptr[offset]);
+}
+
+fn fetch32(ptr: [*]const u8, offset: usize) u32 {
+    return std.mem.readIntLittle(u32, offsetPtr(ptr, offset)[0..4]);
+}
+
+fn fetch64(ptr: [*]const u8, offset: usize) u64 {
+    return std.mem.readIntLittle(u64, offsetPtr(ptr, offset)[0..8]);
+}
+
 pub const CityHash32 = struct {
     const Self = @This();
 
@@ -13,14 +26,6 @@ pub const CityHash32 = struct {
     const c1: u32 = 0xcc9e2d51;
     const c2: u32 = 0x1b873593;
 
-    fn fetch32(ptr: [*]const u8) u32 {
-        var v: u32 = undefined;
-        @memcpy(@ptrCast([*]u8, &v), ptr, 4);
-        if (builtin.endian == .Big)
-            return @byteSwap(u32, v);
-        return v;
-    }
-
     // A 32-bit to 32-bit integer hash copied from Murmur3.
     fn fmix(h: u32) u32 {
         var h1: u32 = h;
@@ -66,21 +71,21 @@ pub const CityHash32 = struct {
         var c: u32 = 9;
         const d: u32 = b;
 
-        a +%= fetch32(str.ptr);
-        b +%= fetch32(str.ptr + str.len - 4);
-        c +%= fetch32(str.ptr + ((str.len >> 1) & 4));
+        a +%= fetch32(str.ptr, 0);
+        b +%= fetch32(str.ptr, str.len - 4);
+        c +%= fetch32(str.ptr, (str.len >> 1) & 4);
 
         return fmix(mur(c, mur(b, mur(a, d))));
     }
 
     fn hash32Len13To24(str: []const u8) u32 {
         const len: u32 = @truncate(u32, str.len);
-        const a: u32 = fetch32(str.ptr + (str.len >> 1) - 4);
-        const b: u32 = fetch32(str.ptr + 4);
-        const c: u32 = fetch32(str.ptr + str.len - 8);
-        const d: u32 = fetch32(str.ptr + (str.len >> 1));
-        const e: u32 = fetch32(str.ptr);
-        const f: u32 = fetch32(str.ptr + str.len - 4);
+        const a: u32 = fetch32(str.ptr, (str.len >> 1) - 4);
+        const b: u32 = fetch32(str.ptr, 4);
+        const c: u32 = fetch32(str.ptr, str.len - 8);
+        const d: u32 = fetch32(str.ptr, str.len >> 1);
+        const e: u32 = fetch32(str.ptr, 0);
+        const f: u32 = fetch32(str.ptr, str.len - 4);
 
         return fmix(mur(f, mur(e, mur(d, mur(c, mur(b, mur(a, len)))))));
     }
@@ -101,11 +106,11 @@ pub const CityHash32 = struct {
         var g: u32 = c1 *% len;
         var f: u32 = g;
 
-        const a0: u32 = rotr32(fetch32(str.ptr + str.len - 4) *% c1, 17) *% c2;
-        const a1: u32 = rotr32(fetch32(str.ptr + str.len - 8) *% c1, 17) *% c2;
-        const a2: u32 = rotr32(fetch32(str.ptr + str.len - 16) *% c1, 17) *% c2;
-        const a3: u32 = rotr32(fetch32(str.ptr + str.len - 12) *% c1, 17) *% c2;
-        const a4: u32 = rotr32(fetch32(str.ptr + str.len - 20) *% c1, 17) *% c2;
+        const a0: u32 = rotr32(fetch32(str.ptr, str.len - 4) *% c1, 17) *% c2;
+        const a1: u32 = rotr32(fetch32(str.ptr, str.len - 8) *% c1, 17) *% c2;
+        const a2: u32 = rotr32(fetch32(str.ptr, str.len - 16) *% c1, 17) *% c2;
+        const a3: u32 = rotr32(fetch32(str.ptr, str.len - 12) *% c1, 17) *% c2;
+        const a4: u32 = rotr32(fetch32(str.ptr, str.len - 20) *% c1, 17) *% c2;
 
         h ^= a0;
         h = rotr32(h, 19);
@@ -125,11 +130,11 @@ pub const CityHash32 = struct {
         var iters = (str.len - 1) / 20;
         var ptr = str.ptr;
         while (iters != 0) : (iters -= 1) {
-            const b0: u32 = rotr32(fetch32(ptr) *% c1, 17) *% c2;
-            const b1: u32 = fetch32(ptr + 4);
-            const b2: u32 = rotr32(fetch32(ptr + 8) *% c1, 17) *% c2;
-            const b3: u32 = rotr32(fetch32(ptr + 12) *% c1, 17) *% c2;
-            const b4: u32 = fetch32(ptr + 16);
+            const b0: u32 = rotr32(fetch32(ptr, 0) *% c1, 17) *% c2;
+            const b1: u32 = fetch32(ptr, 4);
+            const b2: u32 = rotr32(fetch32(ptr, 8) *% c1, 17) *% c2;
+            const b3: u32 = rotr32(fetch32(ptr, 12) *% c1, 17) *% c2;
+            const b4: u32 = fetch32(ptr, 16);
 
             h ^= b0;
             h = rotr32(h, 18);
@@ -152,7 +157,7 @@ pub const CityHash32 = struct {
             h = f;
             f = g;
             g = t;
-            ptr += 20;
+            ptr = offsetPtr(ptr, 20);
         }
         g = rotr32(g, 11) *% c1;
         g = rotr32(g, 17) *% c1;
@@ -176,22 +181,6 @@ pub const CityHash64 = struct {
     const k1: u64 = 0xb492b66fbe98f273;
     const k2: u64 = 0x9ae16a3b2f90404f;
 
-    fn fetch32(ptr: [*]const u8) u32 {
-        var v: u32 = undefined;
-        @memcpy(@ptrCast([*]u8, &v), ptr, 4);
-        if (builtin.endian == .Big)
-            return @byteSwap(u32, v);
-        return v;
-    }
-
-    fn fetch64(ptr: [*]const u8) u64 {
-        var v: u64 = undefined;
-        @memcpy(@ptrCast([*]u8, &v), ptr, 8);
-        if (builtin.endian == .Big)
-            return @byteSwap(u64, v);
-        return v;
-    }
-
     // Rotate right helper
     fn rotr64(x: u64, comptime r: u64) u64 {
         return (x >> r) | (x << (64 - r));
@@ -222,16 +211,16 @@ pub const CityHash64 = struct {
         const len: u64 = @as(u64, str.len);
         if (len >= 8) {
             const mul: u64 = k2 +% len *% 2;
-            const a: u64 = fetch64(str.ptr) +% k2;
-            const b: u64 = fetch64(str.ptr + str.len - 8);
+            const a: u64 = fetch64(str.ptr, 0) +% k2;
+            const b: u64 = fetch64(str.ptr, str.len - 8);
             const c: u64 = rotr64(b, 37) *% mul +% a;
             const d: u64 = (rotr64(a, 25) +% b) *% mul;
             return hashLen16Mul(c, d, mul);
         }
         if (len >= 4) {
             const mul: u64 = k2 +% len *% 2;
-            const a: u64 = fetch32(str.ptr);
-            return hashLen16Mul(len +% (a << 3), fetch32(str.ptr + str.len - 4), mul);
+            const a: u64 = fetch32(str.ptr, 0);
+            return hashLen16Mul(len +% (a << 3), fetch32(str.ptr, str.len - 4), mul);
         }
         if (len > 0) {
             const a: u8 = str[0];
@@ -247,10 +236,10 @@ pub const CityHash64 = struct {
     fn hashLen17To32(str: []const u8) u64 {
         const len: u64 = @as(u64, str.len);
         const mul: u64 = k2 +% len *% 2;
-        const a: u64 = fetch64(str.ptr) *% k1;
-        const b: u64 = fetch64(str.ptr + 8);
-        const c: u64 = fetch64(str.ptr + str.len - 8) *% mul;
-        const d: u64 = fetch64(str.ptr + str.len - 16) *% k2;
+        const a: u64 = fetch64(str.ptr, 0) *% k1;
+        const b: u64 = fetch64(str.ptr, 8);
+        const c: u64 = fetch64(str.ptr, str.len - 8) *% mul;
+        const d: u64 = fetch64(str.ptr, str.len - 16) *% k2;
 
         return hashLen16Mul(rotr64(a +% b, 43) +% rotr64(c, 30) +% d, a +% rotr64(b +% k2, 18) +% c, mul);
     }
@@ -258,14 +247,14 @@ pub const CityHash64 = struct {
     fn hashLen33To64(str: []const u8) u64 {
         const len: u64 = @as(u64, str.len);
         const mul: u64 = k2 +% len *% 2;
-        const a: u64 = fetch64(str.ptr) *% k2;
-        const b: u64 = fetch64(str.ptr + 8);
-        const c: u64 = fetch64(str.ptr + str.len - 24);
-        const d: u64 = fetch64(str.ptr + str.len - 32);
-        const e: u64 = fetch64(str.ptr + 16) *% k2;
-        const f: u64 = fetch64(str.ptr + 24) *% 9;
-        const g: u64 = fetch64(str.ptr + str.len - 8);
-        const h: u64 = fetch64(str.ptr + str.len - 16) *% mul;
+        const a: u64 = fetch64(str.ptr, 0) *% k2;
+        const b: u64 = fetch64(str.ptr, 8);
+        const c: u64 = fetch64(str.ptr, str.len - 24);
+        const d: u64 = fetch64(str.ptr, str.len - 32);
+        const e: u64 = fetch64(str.ptr, 16) *% k2;
+        const f: u64 = fetch64(str.ptr, 24) *% 9;
+        const g: u64 = fetch64(str.ptr, str.len - 8);
+        const h: u64 = fetch64(str.ptr, str.len - 16) *% mul;
 
         const u: u64 = rotr64(a +% g, 43) +% (rotr64(b, 30) +% c) *% 9;
         const v: u64 = ((a +% g) ^ d) +% f +% 1;
@@ -297,10 +286,10 @@ pub const CityHash64 = struct {
 
     fn weakHashLen32WithSeeds(ptr: [*]const u8, a: u64, b: u64) WeakPair {
         return @call(.{ .modifier = .always_inline }, weakHashLen32WithSeedsHelper, .{
-            fetch64(ptr),
-            fetch64(ptr + 8),
-            fetch64(ptr + 16),
-            fetch64(ptr + 24),
+            fetch64(ptr, 0),
+            fetch64(ptr, 8),
+            fetch64(ptr, 16),
+            fetch64(ptr, 24),
             a,
             b,
         });
@@ -319,29 +308,29 @@ pub const CityHash64 = struct {
 
         var len: u64 = @as(u64, str.len);
 
-        var x: u64 = fetch64(str.ptr + str.len - 40);
-        var y: u64 = fetch64(str.ptr + str.len - 16) +% fetch64(str.ptr + str.len - 56);
-        var z: u64 = hashLen16(fetch64(str.ptr + str.len - 48) +% len, fetch64(str.ptr + str.len - 24));
-        var v: WeakPair = weakHashLen32WithSeeds(str.ptr + str.len - 64, len, z);
-        var w: WeakPair = weakHashLen32WithSeeds(str.ptr + str.len - 32, y +% k1, x);
+        var x: u64 = fetch64(str.ptr, str.len - 40);
+        var y: u64 = fetch64(str.ptr, str.len - 16) +% fetch64(str.ptr, str.len - 56);
+        var z: u64 = hashLen16(fetch64(str.ptr, str.len - 48) +% len, fetch64(str.ptr, str.len - 24));
+        var v: WeakPair = weakHashLen32WithSeeds(offsetPtr(str.ptr, str.len - 64), len, z);
+        var w: WeakPair = weakHashLen32WithSeeds(offsetPtr(str.ptr, str.len - 32), y +% k1, x);
 
-        x = x *% k1 +% fetch64(str.ptr);
+        x = x *% k1 +% fetch64(str.ptr, 0);
         len = (len - 1) & ~@intCast(u64, 63);
 
         var ptr: [*]const u8 = str.ptr;
         while (true) {
-            x = rotr64(x +% y +% v.first +% fetch64(ptr + 8), 37) *% k1;
-            y = rotr64(y +% v.second +% fetch64(ptr + 48), 42) *% k1;
+            x = rotr64(x +% y +% v.first +% fetch64(ptr, 8), 37) *% k1;
+            y = rotr64(y +% v.second +% fetch64(ptr, 48), 42) *% k1;
             x ^= w.second;
-            y +%= v.first +% fetch64(ptr + 40);
+            y +%= v.first +% fetch64(ptr, 40);
             z = rotr64(z +% w.first, 33) *% k1;
             v = weakHashLen32WithSeeds(ptr, v.second *% k1, x +% w.first);
-            w = weakHashLen32WithSeeds(ptr + 32, z +% w.second, y +% fetch64(ptr + 16));
+            w = weakHashLen32WithSeeds(offsetPtr(ptr, 32), z +% w.second, y +% fetch64(ptr, 16));
             const t: u64 = z;
             z = x;
             x = t;
 
-            ptr += 64;
+            ptr = offsetPtr(ptr, 64);
             len -= 64;
             if (len == 0)
                 break;
@@ -359,27 +348,31 @@ pub const CityHash64 = struct {
     }
 };
 
-fn SMHasherTest(comptime hash_fn: anytype, comptime hashbits: u32) u32 {
-    const hashbytes = hashbits / 8;
+fn SMHasherTest(comptime hash_fn: anytype) u32 {
+    const HashResult = @typeInfo(@TypeOf(hash_fn)).Fn.return_type.?;
+
     var key: [256]u8 = undefined;
-    var hashes: [hashbytes * 256]u8 = undefined;
-    var final: [hashbytes]u8 = undefined;
+    var hashes_bytes: [256 * @sizeOf(HashResult)]u8 = undefined;
+    var final: HashResult = 0;
 
-    @memset(@ptrCast([*]u8, &key[0]), 0, @sizeOf(@TypeOf(key)));
-    @memset(@ptrCast([*]u8, &hashes[0]), 0, @sizeOf(@TypeOf(hashes)));
-    @memset(@ptrCast([*]u8, &final[0]), 0, @sizeOf(@TypeOf(final)));
+    std.mem.set(u8, &key, 0);
+    std.mem.set(u8, &hashes_bytes, 0);
 
     var i: u32 = 0;
     while (i < 256) : (i += 1) {
         key[i] = @intCast(u8, i);
 
-        var h = hash_fn(key[0..i], 256 - i);
-        if (builtin.endian == .Big)
-            h = @byteSwap(@TypeOf(h), h);
-        @memcpy(@ptrCast([*]u8, &hashes[i * hashbytes]), @ptrCast([*]u8, &h), hashbytes);
+        var h: HashResult = hash_fn(key[0..i], 256 - i);
+
+        // comptime can't really do reinterpret casting yet,
+        // so we need to write the bytes manually.
+        for (hashes_bytes[i*@sizeOf(HashResult)..][0..@sizeOf(HashResult)]) |*byte| {
+            byte.* = @truncate(u8, h);
+            h = h >> 8;
+        }
     }
 
-    return @truncate(u32, hash_fn(&hashes, 0));
+    return @truncate(u32, hash_fn(&hashes_bytes, 0));
 }
 
 fn CityHash32hashIgnoreSeed(str: []const u8, seed: u32) u32 {
@@ -387,13 +380,28 @@ fn CityHash32hashIgnoreSeed(str: []const u8, seed: u32) u32 {
 }
 
 test "cityhash32" {
-    // Note: SMHasher doesn't provide a 32bit version of the algorithm.
-    // Note: The implementation was verified against the Google Abseil version.
-    std.testing.expectEqual(SMHasherTest(CityHash32hashIgnoreSeed, 32), 0x68254F81);
+    const Test = struct {
+        fn doTest() void {
+            // Note: SMHasher doesn't provide a 32bit version of the algorithm.
+            // Note: The implementation was verified against the Google Abseil version.
+            std.testing.expectEqual(SMHasherTest(CityHash32hashIgnoreSeed), 0x68254F81);
+            std.testing.expectEqual(SMHasherTest(CityHash32hashIgnoreSeed), 0x68254F81);
+        }
+    };
+    Test.doTest();
+    @setEvalBranchQuota(50000);
+    comptime Test.doTest();
 }
 
 test "cityhash64" {
-    // Note: This is not compliant with the SMHasher implementation of CityHash64!
-    // Note: The implementation was verified against the Google Abseil version.
-    std.testing.expectEqual(SMHasherTest(CityHash64.hashWithSeed, 64), 0x5FABC5C5);
+    const Test = struct {
+        fn doTest() void {
+            // Note: This is not compliant with the SMHasher implementation of CityHash64!
+            // Note: The implementation was verified against the Google Abseil version.
+            std.testing.expectEqual(SMHasherTest(CityHash64.hashWithSeed), 0x5FABC5C5);
+        }
+    };
+    Test.doTest();
+    @setEvalBranchQuota(50000);
+    comptime Test.doTest();
 }