Commit 98bbd959b0

dweiller <4678790+dweiller@users.noreply.github.com>
2023-02-06 03:19:24
std.compress.zstandard: improve block size validation
1 parent ece52e0
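
The commit threads the frame's block_size_max down into the block decoders instead of each decoder deriving its own cap, so the limit can also be enforced in the streaming reader path where the decompressed size is only known after decoding. A minimal sketch of the cap and the up-front check, assuming block_size_max follows the zstd rule of min(window size, 128 KiB); the helper names are hypothetical and not part of this commit:

```zig
const std = @import("std");

// Hypothetical helper mirroring the cap the callers are now expected to supply:
// a zstd block may not exceed the smaller of the window size and 128 KiB.
fn blockSizeMax(window_size: usize) usize {
    return @min(1 << 17, window_size);
}

// Sketch of the up-front check decodeBlock keeps performing once the caller
// passes block_size_max in, rather than computing it from the remaining dest.
fn validateBlockSize(block_size: usize, block_size_max: usize) error{BlockSizeOverMaximum}!void {
    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
}

test "block size cap" {
    const max = blockSizeMax(1 << 20); // 1 MiB window, so capped at 128 KiB
    try validateBlockSize(1 << 16, max);
    try std.testing.expectError(error.BlockSizeOverMaximum, validateBlockSize((1 << 17) + 1, max));
}
```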
Changed files (2)
lib/std/compress/zstandard/decode/block.zig
@@ -594,9 +594,9 @@ pub fn decodeBlock(
     block_header: frame.Zstandard.Block.Header,
     decode_state: *DecodeState,
     consumed_count: *usize,
+    block_size_max: usize,
     written_count: usize,
 ) (error{DestTooSmall} || Error)!usize {
-    const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
@@ -805,6 +805,7 @@ pub fn decodeBlockReader(
 
             try decode_state.prepare(block_reader, literals, sequences_header);
 
+            var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 if (sequence_buffer.len < block_reader_limited.bytes_left)
                     return error.SequenceBufferTooSmall;
@@ -825,6 +826,7 @@ pub fn decodeBlockReader(
                         i == sequences_header.sequence_count - 1,
                     ) catch return error.MalformedCompressedBlock;
                     sequence_size_limit -= decompressed_size;
+                    bytes_written += decompressed_size;
                 }
             }
 
@@ -832,8 +834,10 @@ pub fn decodeBlockReader(
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
                 decode_state.decodeLiteralsRingBuffer(dest, len) catch
                     return error.MalformedCompressedBlock;
+                bytes_written += len;
             }
 
+            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
             decode_state.literal_written_count = 0;
             assert(block_reader.readByte() == error.EndOfStream);
         },
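
In decodeBlockReader the decompressed size of a compressed block is only known after the sequences and any remaining literals have been executed, so the hunks above accumulate bytes_written and compare it against the cap afterwards. A self-contained sketch of that accumulate-then-check pattern, where the slice of per-sequence sizes and the literal length stand in for the real decoding results:

```zig
const std = @import("std");

// Sum the sizes produced by decoding and reject the block if the total
// exceeds the cap, mirroring the check added at the end of decodeBlockReader.
fn checkDecodedBlockSize(
    sequence_sizes: []const usize,
    literal_len: usize,
    block_size_max: usize,
) error{BlockSizeOverMaximum}!usize {
    var bytes_written: usize = 0;
    for (sequence_sizes) |size| bytes_written += size;
    bytes_written += literal_len;
    if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
    return bytes_written;
}

test "decoded block size is capped" {
    try std.testing.expectEqual(@as(usize, 6), try checkDecodedBlockSize(&[_]usize{ 1, 2 }, 3, 1 << 17));
    try std.testing.expectError(
        error.BlockSizeOverMaximum,
        checkDecodedBlockSize(&[_]usize{1 << 17}, 1, 1 << 17),
    );
}
```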
lib/std/compress/zstandard/decompress.zig
@@ -236,6 +236,7 @@ pub fn decodeZstandardFrame(
         src[consumed_count..],
         &consumed_count,
         if (frame_context.hasher_opt) |*hasher| hasher else null,
+        frame_context.block_size_max,
     ) catch |err| switch (err) {
         error.DestTooSmall => return error.BadContentSize,
         inline else => |e| return e,
@@ -376,7 +377,6 @@ pub fn decodeZstandardFrameArrayList(
         block_header = try block.decodeBlockHeaderSlice(src[consumed_count..]);
         consumed_count += 3;
     }) {
-        if (block_header.block_size > frame_context.block_size_max) return error.BlockSizeOverMaximum;
         const written_size = try block.decodeBlockRingBuffer(
             &ring_buffer,
             src[consumed_count..],
@@ -420,6 +420,7 @@ fn decodeFrameBlocks(
     src: []const u8,
     consumed_count: *usize,
     hash: ?*std.hash.XxHash64,
+    block_size_max: usize,
 ) (error{ EndOfStream, DestTooSmall } || block.Error)!usize {
     // These tables take 7680 bytes
     var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
@@ -441,6 +442,7 @@ fn decodeFrameBlocks(
             block_header,
             &decode_state,
             &bytes_read,
+            block_size_max,
             written_count,
         );
         if (hash) |hash_state| hash_state.update(dest[written_count .. written_count + written_size]);