Commit 326b7b794b

Marc Tiehuis <marc@tiehu.is>
2019-08-27 10:13:57
Improve siphash performance for small keys by up to 30% (#3124)
This removes the partial buffer handling from the full slice api.

`./benchmark --filter siphash --count 1024`

old

siphash(1,3)
   iterative: 3388 MiB/s [67532e53a0d210bf]
  small keys: 1258 MiB/s [948c91176a000000]
siphash(2,4)
   iterative: 2061 MiB/s [f792d39bff42f819]
  small keys: 902 MiB/s [e1ecba6614000000]

new

siphash(1,3)
   iterative: 3410 MiB/s [67532e53a0d210bf]
  small keys: 1639 MiB/s [948c91176a000000]
siphash(2,4)
   iterative: 2053 MiB/s [f792d39bff42f819]
  small keys: 1074 MiB/s [e1ecba6614000000]
1 parent 1df75da
Changed files (1)
std
std/hash/siphash.zig
@@ -21,7 +21,7 @@ pub fn SipHash128(comptime c_rounds: usize, comptime d_rounds: usize) type {
     return SipHash(u128, c_rounds, d_rounds);
 }
 
-fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize) type {
+fn SipHashStateless(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize) type {
     assert(T == u64 or T == u128);
     assert(c_rounds > 0 and d_rounds > 0);
 
@@ -34,10 +34,6 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize)
         v1: u64,
         v2: u64,
         v3: u64,
-
-        // streaming cache
-        buf: [8]u8,
-        buf_len: usize,
         msg_len: u8,
 
         pub fn init(key: []const u8) Self {
@@ -51,9 +47,6 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize)
                 .v1 = k1 ^ 0x646f72616e646f6d,
                 .v2 = k0 ^ 0x6c7967656e657261,
                 .v3 = k1 ^ 0x7465646279746573,
-
-                .buf = undefined,
-                .buf_len = 0,
                 .msg_len = 0,
             };
 
@@ -64,73 +57,66 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize)
             return d;
         }
 
-        pub fn update(d: *Self, b: []const u8) void {
-            var off: usize = 0;
-
-            // Partial from previous.
-            if (d.buf_len != 0 and d.buf_len + b.len > 8) {
-                off += 8 - d.buf_len;
-                mem.copy(u8, d.buf[d.buf_len..], b[0..off]);
-                d.round(d.buf[0..]);
-                d.buf_len = 0;
-            }
+        pub fn update(self: *Self, b: []const u8) void {
+            std.debug.assert(b.len % 8 == 0);
 
-            // Full middle blocks.
-            while (off + 8 <= b.len) : (off += 8) {
-                d.round(b[off .. off + 8]);
+            var off: usize = 0;
+            while (off < b.len) : (off += 8) {
+                @inlineCall(self.round, b[off .. off + 8]);
             }
 
-            // Remainder for next pass.
-            mem.copy(u8, d.buf[d.buf_len..], b[off..]);
-            d.buf_len += @intCast(u8, b[off..].len);
-            d.msg_len +%= @truncate(u8, b.len);
+            self.msg_len +%= @truncate(u8, b.len);
         }
 
-        pub fn final(d: *Self) T {
-            // Padding
-            mem.set(u8, d.buf[d.buf_len..], 0);
-            d.buf[7] = d.msg_len;
-            d.round(d.buf[0..]);
+        pub fn final(self: *Self, b: []const u8) T {
+            std.debug.assert(b.len < 8);
+
+            self.msg_len +%= @truncate(u8, b.len);
+
+            var buf = [_]u8{0} ** 8;
+            mem.copy(u8, buf[0..], b[0..]);
+            buf[7] = self.msg_len;
+            self.round(buf[0..]);
 
             if (T == u128) {
-                d.v2 ^= 0xee;
+                self.v2 ^= 0xee;
             } else {
-                d.v2 ^= 0xff;
+                self.v2 ^= 0xff;
             }
 
             comptime var i: usize = 0;
             inline while (i < d_rounds) : (i += 1) {
-                @inlineCall(sipRound, d);
+                @inlineCall(sipRound, self);
             }
 
-            const b1 = d.v0 ^ d.v1 ^ d.v2 ^ d.v3;
+            const b1 = self.v0 ^ self.v1 ^ self.v2 ^ self.v3;
             if (T == u64) {
                 return b1;
             }
 
-            d.v1 ^= 0xdd;
+            self.v1 ^= 0xdd;
 
             comptime var j: usize = 0;
             inline while (j < d_rounds) : (j += 1) {
-                @inlineCall(sipRound, d);
+                @inlineCall(sipRound, self);
             }
 
-            const b2 = d.v0 ^ d.v1 ^ d.v2 ^ d.v3;
+            const b2 = self.v0 ^ self.v1 ^ self.v2 ^ self.v3;
             return (u128(b2) << 64) | b1;
         }
 
-        fn round(d: *Self, b: []const u8) void {
+        fn round(self: *Self, b: []const u8) void {
             assert(b.len == 8);
 
             const m = mem.readIntSliceLittle(u64, b[0..]);
-            d.v3 ^= m;
+            self.v3 ^= m;
 
             comptime var i: usize = 0;
             inline while (i < c_rounds) : (i += 1) {
-                @inlineCall(sipRound, d);
+                @inlineCall(sipRound, self);
             }
 
-            d.v0 ^= m;
+            self.v0 ^= m;
         }
 
         fn sipRound(d: *Self) void {
@@ -151,9 +137,61 @@ fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize)
         }
 
         pub fn hash(key: []const u8, input: []const u8) T {
+            const aligned_len = input.len - (input.len % 8);
+
             var c = Self.init(key);
-            @inlineCall(c.update, input);
-            return @inlineCall(c.final);
+            @inlineCall(c.update, input[0..aligned_len]);
+            return @inlineCall(c.final, input[aligned_len..]);
+        }
+    };
+}
+
+pub fn SipHash(comptime T: type, comptime c_rounds: usize, comptime d_rounds: usize) type {
+    assert(T == u64 or T == u128);
+    assert(c_rounds > 0 and d_rounds > 0);
+
+    return struct {
+        const State = SipHashStateless(T, c_rounds, d_rounds);
+        const Self = @This();
+        const digest_size = 64;
+        const block_size = 64;
+
+        state: State,
+        buf: [8]u8,
+        buf_len: usize,
+
+        pub fn init(key: []const u8) Self {
+            return Self{
+                .state = State.init(key),
+                .buf = undefined,
+                .buf_len = 0,
+            };
+        }
+
+        pub fn update(self: *Self, b: []const u8) void {
+            var off: usize = 0;
+
+            if (self.buf_len != 0 and self.buf_len + b.len >= 8) {
+                off += 8 - self.buf_len;
+                mem.copy(u8, self.buf[self.buf_len..], b[0..off]);
+                self.state.update(self.buf[0..]);
+                self.buf_len = 0;
+            }
+
+            const remain_len = b.len - off;
+            const aligned_len = remain_len - (remain_len % 8);
+            self.state.update(b[off .. off + aligned_len]);
+
+            mem.copy(u8, self.buf[self.buf_len..], b[off + aligned_len ..]);
+            self.buf_len += @intCast(u8, b[off + aligned_len ..].len);
+        }
+
+        pub fn final(self: *Self) T {
+            return self.state.final(self.buf[0..self.buf_len]);
+        }
+
+        pub fn hash(key: []const u8, input: []const u8) T {
+            return State.hash(key, input);
         }
     };
 }
@@ -237,7 +275,7 @@ test "siphash64-2-4 sanity" {
         buffer[i] = @intCast(u8, i);
 
         const expected = mem.readIntLittle(u64, &vector);
-        testing.expect(siphash.hash(test_key, buffer[0..i]) == expected);
+        testing.expectEqual(siphash.hash(test_key, buffer[0..i]), expected);
     }
 }
 
@@ -316,6 +354,30 @@ test "siphash128-2-4 sanity" {
         buffer[i] = @intCast(u8, i);
 
         const expected = mem.readIntLittle(u128, &vector);
-        testing.expect(siphash.hash(test_key, buffer[0..i]) == expected);
+        testing.expectEqual(siphash.hash(test_key, buffer[0..i]), expected);
+    }
+}
+
+test "iterative non-divisible update" {
+    var buf: [1024]u8 = undefined;
+    for (buf) |*e, i| {
+        e.* = @truncate(u8, i);
+    }
+
+    const key = "0x128dad08f12307";
+    const Siphash = SipHash64(2, 4);
+
+    var end: usize = 9;
+    while (end < buf.len) : (end += 9) {
+        const non_iterative_hash = Siphash.hash(key, buf[0..end]);
+
+        var wy = Siphash.init(key);
+        var i: usize = 0;
+        while (i < end) : (i += 7) {
+            wy.update(buf[i..std.math.min(i + 7, end)]);
+        }
+        const iterative_hash = wy.final();
+
+        std.testing.expectEqual(iterative_hash, non_iterative_hash);
     }
 }