Commit e92575d3d4

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-28 12:02:08
std.compress.zstandard: verify checksum in decodeFrameAlloc()
1 parent 3bfba36
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -573,6 +573,11 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma
 const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 
+pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
+    const hash = hasher.final();
+    return @intCast(u32, hash & 0xFFFFFFFF);
+}
+
 const FrameError = error{
     DictionaryIdFlagUnsupported,
     ChecksumFailure,
@@ -601,24 +606,20 @@ pub fn decodeZStandardFrame(
     if (dest.len < content_size) return error.ContentTooLarge;
 
     const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
-    var hash_state = if (should_compute_checksum) std.hash.XxHash64.init(0) else undefined;
+    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
 
     const written_count = try decodeFrameBlocks(
         dest,
         src[consumed_count..],
         &consumed_count,
-        if (should_compute_checksum) &hash_state else null,
+        if (hasher_opt) |*hasher| hasher else null,
     );
 
     if (frame_header.descriptor.content_checksum_flag) {
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
-        if (verify_checksum) {
-            const hash = hash_state.final();
-            const hash_low_bytes = hash & 0xFFFFFFFF;
-            if (checksum != hash_low_bytes) {
-                return error.ChecksumFailure;
-            }
+        if (hasher_opt) |*hasher| {
+            if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
         }
     }
     return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count };
@@ -649,7 +650,7 @@ pub fn decodeZStandardFrameAlloc(
         @intCast(usize, window_size_raw);
 
     const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
-    var hash = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
+    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
 
     const block_size_maximum = @min(1 << 17, window_size);
 
@@ -707,12 +708,20 @@ pub fn decodeZStandardFrameAlloc(
         const written_slice = ring_buffer.sliceLast(written_size);
         try result.appendSlice(written_slice.first);
         try result.appendSlice(written_slice.second);
-        if (hash) |*hash_state| {
-            hash_state.update(written_slice.first);
-            hash_state.update(written_slice.second);
+        if (hasher_opt) |*hasher| {
+            hasher.update(written_slice.first);
+            hasher.update(written_slice.second);
         }
         if (block_header.last_block) break;
     }
+
+    if (frame_header.descriptor.content_checksum_flag) {
+        const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
+        consumed_count += 4;
+        if (hasher_opt) |*hasher| {
+            if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
+        }
+    }
     return result.toOwnedSlice();
 }