Commit fbcdf78cbd

Marc Tiehuis <marc@tiehu.is>
2019-08-26 12:32:39
Simplify wyhash and improve speed
This removes the exposed stateless variant since the standard variant has similar speed now. Using `./benchmark --filter wyhash --count 1024`, the speed change has changed from: wyhash iterative: 4093 MiB/s [6f76b0d5db7db34c] small keys: 3132 MiB/s [28c2f43c70000000] to wyhash iterative: 6515 MiB/s [673e9bb86da93ea4] small keys: 10487 MiB/s [28c2f43c70000000]
1 parent 90e921f
Changed files (3)
std/hash/benchmark.zig
@@ -28,11 +28,6 @@ const hashes = [_]Hash{
         .name = "wyhash",
         .init_u64 = 0,
     },
-    Hash{
-        .ty = hash.WyhashStateless,
-        .name = "wyhash-stateless",
-        .init_u64 = 0,
-    },
     Hash{
         .ty = hash.SipHash64(1, 3),
         .name = "siphash(1,3)",
@@ -91,7 +86,7 @@ const Result = struct {
     throughput: u64,
 };
 
-const block_size: usize = 8192;
+const block_size: usize = 8 * 8192;
 
 pub fn benchmarkHash(comptime H: var, bytes: usize) !Result {
     var h = blk: {
std/hash/wyhash.zig
@@ -10,7 +10,8 @@ const primes = [_]u64{
 };
 
 fn read_bytes(comptime bytes: u8, data: []const u8) u64 {
-    return mem.readVarInt(u64, data[0..bytes], .Little);
+    const T = @IntType(false, 8 * bytes);
+    return mem.readIntSliceLittle(T, data[0..bytes]);
 }
 
 fn read_8bytes_swapped(data: []const u8) u64 {
@@ -31,25 +32,21 @@ fn mix1(a: u64, b: u64, seed: u64) u64 {
     return mum(a ^ seed ^ primes[2], b ^ seed ^ primes[3]);
 }
 
-/// Fast non-cryptographic 64bit hash function.
-/// See https://github.com/wangyi-fudan/wyhash
-pub const Wyhash = struct {
+// Wyhash version which does not store internal state for handling partial buffers.
+// This is needed so that we can maximize the speed for the short key case, which will
+// use the non-iterative api which the public Wyhash exposes.
+const WyhashStateless = struct {
     seed: u64,
-
-    buf: [32]u8,
-    buf_len: usize,
     msg_len: usize,
 
-    pub fn init(seed: u64) Wyhash {
-        return Wyhash{
+    pub fn init(seed: u64) WyhashStateless {
+        return WyhashStateless{
             .seed = seed,
-            .buf = undefined,
-            .buf_len = 0,
             .msg_len = 0,
         };
     }
 
-    fn round(self: *Wyhash, b: []const u8) void {
+    fn round(self: *WyhashStateless, b: []const u8) void {
         std.debug.assert(b.len == 32);
 
         self.seed = mix0(
@@ -63,32 +60,23 @@ pub const Wyhash = struct {
         );
     }
 
-    pub fn update(self: *Wyhash, b: []const u8) void {
-        var off: usize = 0;
-
-        // Partial from previous.
-        if (self.buf_len != 0 and self.buf_len + b.len > 32) {
-            off += 32 - self.buf_len;
-            mem.copy(u8, self.buf[self.buf_len..], b[0..off]);
-            self.round(self.buf[0..]);
-            self.buf_len = 0;
-        }
+    pub fn update(self: *WyhashStateless, b: []const u8) void {
+        std.debug.assert(b.len % 32 == 0);
 
-        // Full middle blocks.
-        while (off + 32 <= b.len) : (off += 32) {
+        var off: usize = 0;
+        while (off < b.len) : (off += 32) {
             @inlineCall(self.round, b[off .. off + 32]);
         }
 
-        // Remainder for next pass.
-        mem.copy(u8, self.buf[self.buf_len..], b[off..]);
-        self.buf_len += @intCast(u8, b[off..].len);
         self.msg_len += b.len;
     }
 
-    pub fn final(self: *Wyhash) u64 {
+    pub fn final(self: *WyhashStateless, b: []const u8) u64 {
+        std.debug.assert(b.len < 32);
+
         const seed = self.seed;
-        const rem_len = @intCast(u5, self.buf_len);
-        const rem_key = self.buf[0..self.buf_len];
+        const rem_len = @intCast(u5, b.len);
+        const rem_key = b[0..rem_len];
 
         self.seed = switch (rem_len) {
             0 => seed,
@@ -125,109 +113,63 @@ pub const Wyhash = struct {
             31 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), (read_bytes(4, rem_key[24..]) << 24) | (read_bytes(2, rem_key[28..]) << 8) | read_bytes(1, rem_key[30..]), seed),
         };
 
+        self.msg_len += b.len;
         return mum(self.seed ^ self.msg_len, primes[4]);
     }
 
     pub fn hash(seed: u64, input: []const u8) u64 {
-        var c = Wyhash.init(seed);
-        @inlineCall(c.update, input);
-        return @inlineCall(c.final);
-    }
-};
-
-/// Wyhash version where state is not preserved between successive `update`
-/// calls, ie. it will have different results between hashing the data in
-/// one or several steps.
-/// This allows it to be faster.
-pub const WyhashStateless = struct {
-    seed: u64,
-    msg_len: usize,
+        const aligned_len = input.len - (input.len % 32);
 
-    const Self = @This();
-
-    pub fn init(seed: u64) Self {
-        return Self{
-            .seed = seed,
-            .msg_len = 0,
-        };
+        var c = WyhashStateless.init(seed);
+        @inlineCall(c.update, input[0..aligned_len]);
+        return @inlineCall(c.final, input[aligned_len..]);
     }
+};
 
-    fn round(self: *Self, b: []const u8) void {
-        std.debug.assert(b.len == 32);
-
-        self.seed = mix0(
-            read_bytes(8, b[0..]),
-            read_bytes(8, b[8..]),
-            self.seed,
-        ) ^ mix1(
-            read_bytes(8, b[16..]),
-            read_bytes(8, b[24..]),
-            self.seed,
-        );
-    }
+/// Fast non-cryptographic 64bit hash function.
+/// See https://github.com/wangyi-fudan/wyhash
+pub const Wyhash = struct {
+    state: WyhashStateless,
 
-    fn partial(self: *Self, b: []const u8) void {
-        const rem_key = b;
-        const rem_len = b.len;
+    buf: [32]u8,
+    buf_len: usize,
 
-        var seed = self.seed;
-        seed = switch (@intCast(u5, rem_len)) {
-            0 => seed,
-            1 => mix0(read_bytes(1, rem_key), primes[4], seed),
-            2 => mix0(read_bytes(2, rem_key), primes[4], seed),
-            3 => mix0((read_bytes(2, rem_key) << 8) | read_bytes(1, rem_key[2..]), primes[4], seed),
-            4 => mix0(read_bytes(4, rem_key), primes[4], seed),
-            5 => mix0((read_bytes(4, rem_key) << 8) | read_bytes(1, rem_key[4..]), primes[4], seed),
-            6 => mix0((read_bytes(4, rem_key) << 16) | read_bytes(2, rem_key[4..]), primes[4], seed),
-            7 => mix0((read_bytes(4, rem_key) << 24) | (read_bytes(2, rem_key[4..]) << 8) | read_bytes(1, rem_key[6..]), primes[4], seed),
-            8 => mix0(read_8bytes_swapped(rem_key), primes[4], seed),
-            9 => mix0(read_8bytes_swapped(rem_key), read_bytes(1, rem_key[8..]), seed),
-            10 => mix0(read_8bytes_swapped(rem_key), read_bytes(2, rem_key[8..]), seed),
-            11 => mix0(read_8bytes_swapped(rem_key), (read_bytes(2, rem_key[8..]) << 8) | read_bytes(1, rem_key[10..]), seed),
-            12 => mix0(read_8bytes_swapped(rem_key), read_bytes(4, rem_key[8..]), seed),
-            13 => mix0(read_8bytes_swapped(rem_key), (read_bytes(4, rem_key[8..]) << 8) | read_bytes(1, rem_key[12..]), seed),
-            14 => mix0(read_8bytes_swapped(rem_key), (read_bytes(4, rem_key[8..]) << 16) | read_bytes(2, rem_key[12..]), seed),
-            15 => mix0(read_8bytes_swapped(rem_key), (read_bytes(4, rem_key[8..]) << 24) | (read_bytes(2, rem_key[12..]) << 8) | read_bytes(1, rem_key[14..]), seed),
-            16 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed),
-            17 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_bytes(1, rem_key[16..]), primes[4], seed),
-            18 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_bytes(2, rem_key[16..]), primes[4], seed),
-            19 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1((read_bytes(2, rem_key[16..]) << 8) | read_bytes(1, rem_key[18..]), primes[4], seed),
-            20 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_bytes(4, rem_key[16..]), primes[4], seed),
-            21 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1((read_bytes(4, rem_key[16..]) << 8) | read_bytes(1, rem_key[20..]), primes[4], seed),
-            22 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1((read_bytes(4, rem_key[16..]) << 16) | read_bytes(2, rem_key[20..]), primes[4], seed),
-            23 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1((read_bytes(4, rem_key[16..]) << 24) | (read_bytes(2, rem_key[20..]) << 8) | read_bytes(1, rem_key[22..]), primes[4], seed),
-            24 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), primes[4], seed),
-            25 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), read_bytes(1, rem_key[24..]), seed),
-            26 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), read_bytes(2, rem_key[24..]), seed),
-            27 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), (read_bytes(2, rem_key[24..]) << 8) | read_bytes(1, rem_key[26..]), seed),
-            28 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), read_bytes(4, rem_key[24..]), seed),
-            29 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), (read_bytes(4, rem_key[24..]) << 8) | read_bytes(1, rem_key[28..]), seed),
-            30 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), (read_bytes(4, rem_key[24..]) << 16) | read_bytes(2, rem_key[28..]), seed),
-            31 => mix0(read_8bytes_swapped(rem_key), read_8bytes_swapped(rem_key[8..]), seed) ^ mix1(read_8bytes_swapped(rem_key[16..]), (read_bytes(4, rem_key[24..]) << 24) | (read_bytes(2, rem_key[28..]) << 8) | read_bytes(1, rem_key[30..]), seed),
+    pub fn init(seed: u64) Wyhash {
+        return Wyhash{
+            .state = WyhashStateless.init(seed),
+            .buf = undefined,
+            .buf_len = 0,
         };
-        self.seed = seed;
     }
 
-    pub fn update(self: *Self, b: []const u8) void {
+    pub fn update(self: *Wyhash, b: []const u8) void {
         var off: usize = 0;
 
-        // Full middle blocks.
-        while (off + 32 <= b.len) : (off += 32) {
-            @inlineCall(self.round, b[off .. off + 32]);
+        if (self.buf_len != 0 and self.buf_len + b.len >= 32) {
+            off += 32 - self.buf_len;
+            mem.copy(u8, self.buf[self.buf_len..], b[0..off]);
+            self.state.update(self.buf[0..]);
+            self.buf_len = 0;
         }
 
-        self.partial(b[off..]);
-        self.msg_len += b.len;
+        const remain_len = b.len - off;
+        const aligned_len = remain_len - (remain_len % 32);
+        self.state.update(b[off .. off + aligned_len]);
+
+        mem.copy(u8, self.buf[self.buf_len..], b[off + aligned_len ..]);
+        self.buf_len += @intCast(u8, b[off + aligned_len ..].len);
     }
 
-    pub fn final(self: *Self) u64 {
-        return mum(self.seed ^ self.msg_len, primes[4]);
+    pub fn final(self: *Wyhash) u64 {
+        const seed = self.state.seed;
+        const rem_len = @intCast(u5, self.buf_len);
+        const rem_key = self.buf[0..self.buf_len];
+
+        return self.state.final(rem_key);
     }
 
     pub fn hash(seed: u64, input: []const u8) u64 {
-        var c = Self.init(seed);
-        @inlineCall(c.update, input);
-        return @inlineCall(c.final);
+        return WyhashStateless.hash(seed, input);
     }
 };
 
@@ -265,17 +207,25 @@ test "test vectors streaming" {
     expectEqual(wh.final(), result);
 }
 
-test "test vectors stateless" {
-    const hash = WyhashStateless.hash;
+test "iterative non-divisible update" {
+    var buf: [8192]u8 = undefined;
+    for (buf) |*e, i| {
+        e.* = @truncate(u8, i);
+    }
 
-    expectEqual(hash(0, ""), 0x0);
-    expectEqual(hash(1, "a"), 0xbed235177f41d328);
-    expectEqual(hash(2, "abc"), 0xbe348debe59b27c3);
-    expectEqual(hash(3, "message digest"), 0x37320f657213a290);
-    expectEqual(hash(4, "abcdefghijklmnopqrstuvwxyz"), 0xd0b270e1d8a7019c);
-    expectEqual(hash(5, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"), 0x602a1894d3bbfe7f);
-    expectEqual(hash(6, "12345678901234567890123456789012345678901234567890123456789012345678901234567890"), 0x829e9c148b75970e);
+    const seed = 0x128dad08f;
+
+    var end: usize = 32;
+    while (end < buf.len) : (end += 32) {
+        const non_iterative_hash = Wyhash.hash(seed, buf[0..end]);
 
-    // We don't check for the streaming API having the same results, as it is
-    // not required to.
+        var wy = Wyhash.init(seed);
+        var i: usize = 0;
+        while (i < end) : (i += 33) {
+            wy.update(buf[i..std.math.min(i + 33, end)]);
+        }
+        const iterative_hash = wy.final();
+
+        std.testing.expectEqual(iterative_hash, non_iterative_hash);
+    }
 }
std/hash.zig
@@ -29,7 +29,6 @@ pub const CityHash64 = cityhash.CityHash64;
 
 const wyhash = @import("hash/wyhash.zig");
 pub const Wyhash = wyhash.Wyhash;
-pub const WyhashStateless = wyhash.WyhashStateless;
 
 test "hash" {
     _ = @import("hash/adler.zig");