Commit ece52e0771

dweiller <4678790+dweiller@users.noreply.github.com>
2023-02-05 12:27:00
std.compress.zstandard: verify content size and fix crash
1 parent a9c8376
Changed files (2)
lib/std/compress/zstandard/decode/block.zig
@@ -334,6 +334,8 @@ pub const DecodeState = struct {
     ///     mean the literal stream or the sequence is malformed).
     ///   - `error.InvalidBitStream` if the FSE sequence bitstream is malformed
     ///   - `error.EndOfStream` if `bit_reader` does not contain enough bits
+    ///   - `error.DestTooSmall` if `dest` is not large enough to hold the
+    ///     decompressed sequence
     pub fn decodeSequenceSlice(
         self: *DecodeState,
         dest: []u8,
@@ -341,10 +343,11 @@ pub const DecodeState = struct {
         bit_reader: *readers.ReverseBitReader,
         sequence_size_limit: usize,
         last_sequence: bool,
-    ) DecodeSequenceError!usize {
+    ) (error{DestTooSmall} || DecodeSequenceError)!usize {
         const sequence = try self.nextSequence(bit_reader);
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
+        if (sequence_length > dest[write_pos..].len) return error.DestTooSmall;
 
         try self.executeSequenceSlice(dest, write_pos, sequence);
         if (!last_sequence) {
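For context on the signature change above: Zig's `||` operator merges error sets, so the new return type only widens the callee's error set by one named error, and callers can single that error out with a `catch`/`switch`. A minimal standalone sketch of the pattern (the function, error set, and test below are illustrative, not part of this commit):

    const std = @import("std");

    const DecodeError = error{ MalformedSequence, EndOfStream };

    // Widening an error set with `||`, as decodeSequenceSlice now does.
    fn decodeOne(dest: []u8, needed: usize) (error{DestTooSmall} || DecodeError)!usize {
        if (needed > dest.len) return error.DestTooSmall;
        return needed;
    }

    test "caller can single out the new error" {
        var buf: [4]u8 = undefined;
        const n = decodeOne(&buf, 8) catch |err| switch (err) {
            error.DestTooSmall => 0, // handled separately from malformed-input errors
            else => |e| return e,
        };
        try std.testing.expectEqual(@as(usize, 0), n);
    }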
@@ -583,6 +586,8 @@ pub const DecodeState = struct {
 ///   - `error.MalformedRleBlock` if the block is an RLE block and `src.len < 1`
 ///   - `error.MalformedCompressedBlock` if there are errors decoding a
 ///     compressed block
+///   - `error.DestTooSmall` if `dest` is not large enough to hold the
+///     decompressed block
 pub fn decodeBlock(
     dest: []u8,
     src: []const u8,
@@ -590,13 +595,14 @@ pub fn decodeBlock(
     decode_state: *DecodeState,
     consumed_count: *usize,
     written_count: usize,
-) Error!usize {
+) (error{DestTooSmall} || Error)!usize {
     const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
         .raw => {
             if (src.len < block_size) return error.MalformedBlockSize;
+            if (dest[written_count..].len < block_size) return error.DestTooSmall;
             const data = src[0..block_size];
             std.mem.copy(u8, dest[written_count..], data);
             consumed_count.* += block_size;
@@ -604,6 +610,7 @@ pub fn decodeBlock(
         },
         .rle => {
             if (src.len < 1) return error.MalformedRleBlock;
+            if (dest[written_count..].len < block_size) return error.DestTooSmall;
             var write_pos: usize = written_count;
             while (write_pos < block_size + written_count) : (write_pos += 1) {
                 dest[write_pos] = src[0];
@@ -644,7 +651,10 @@ pub fn decodeBlock(
                         &bit_stream,
                         sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
-                    ) catch return error.MalformedCompressedBlock;
+                    ) catch |err| switch (err) {
+                        error.DestTooSmall => return error.DestTooSmall,
+                        else => return error.MalformedCompressedBlock,
+                    };
                     bytes_written += decompressed_size;
                     sequence_size_limit -= decompressed_size;
                 }
@@ -655,6 +665,7 @@ pub fn decodeBlock(
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
+                if (len > dest[written_count + bytes_written ..].len) return error.DestTooSmall;
                 decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
                     return error.MalformedCompressedBlock;
                 bytes_written += len;
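The `.raw`, `.rle`, sequence, and literal paths above all guard the same failure mode: writing `block_size` (or the remaining literal) bytes into whatever space is left in `dest`, which could previously index past the end of the slice and panic in safe builds, presumably the crash the commit title refers to. A hedged sketch of the guard pattern in isolation (not the actual decoder code):

    const std = @import("std");

    // Check the remaining destination space before copying, surfacing the
    // shortfall as an error instead of an out-of-bounds panic.
    fn writeBlockChecked(dest: []u8, written_count: usize, data: []const u8) error{DestTooSmall}!void {
        if (dest[written_count..].len < data.len) return error.DestTooSmall;
        std.mem.copy(u8, dest[written_count..], data);
    }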
lib/std/compress/zstandard/decompress.zig
@@ -96,6 +96,7 @@ pub fn decodeAlloc(
 ///   - `error.UnknownContentSizeUnsupported` if the frame does not declare the
 ///     uncompressed content size
 ///   - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data
+///     size declared by the frame header
 ///   - `error.BadMagic` if the first 4 bytes of `src` is not a valid magic
 ///     number for a Zstandard or Skippable frame
 ///   - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
@@ -180,6 +181,7 @@ pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
 const FrameError = error{
     DictionaryIdFlagUnsupported,
     ChecksumFailure,
+    BadContentSize,
     EndOfStream,
 } || InvalidBit || block.Error;
 
@@ -191,7 +193,7 @@ const FrameError = error{
 ///   - `error.UnknownContentSizeUnsupported` if the frame does not declare the
 ///     uncompressed content size
 ///   - `error.ContentTooLarge` if `dest` is smaller than the uncompressed data
-///     number for a Zstandard or Skippable frame
+///     size declared by the frame header
 ///   - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
 ///   - `error.ChecksumFailure` if `verify_checksum` is true and the frame
 ///     contains a checksum that does not match the checksum of the decompressed
@@ -200,39 +202,51 @@ const FrameError = error{
 ///   - `error.UnusedBitSet` if the unused bit of the frame header is set
 ///   - `error.EndOfStream` if `src` does not contain a complete frame
 ///   - an error in `block.Error` if there are errors decoding a block
+///   - `error.BadContentSize` if the content size declared by the frame does
+///     not equal the actual size of decompressed data
 pub fn decodeZstandardFrame(
     dest: []u8,
     src: []const u8,
     verify_checksum: bool,
-) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount {
+) (error{
+    UnknownContentSizeUnsupported,
+    ContentTooLarge,
+    ContentSizeTooLarge,
+    WindowSizeUnknown,
+} || FrameError)!ReadWriteCount {
     assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number);
     var consumed_count: usize = 4;
 
-    var fbs = std.io.fixedBufferStream(src[consumed_count..]);
-    var source = fbs.reader();
-    const frame_header = try decodeZstandardHeader(source);
-    consumed_count += fbs.pos;
-
-    if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
+    var frame_context = context: {
+        var fbs = std.io.fixedBufferStream(src[consumed_count..]);
+        var source = fbs.reader();
+        const frame_header = try decodeZstandardHeader(source);
+        consumed_count += fbs.pos;
+        break :context FrameContext.init(frame_header, std.math.maxInt(usize), verify_checksum) catch |err| switch (err) {
+            error.WindowTooLarge => unreachable,
+            inline else => |e| return e,
+        };
+    };
 
-    const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported;
+    const content_size = frame_context.content_size orelse return error.UnknownContentSizeUnsupported;
     if (dest.len < content_size) return error.ContentTooLarge;
 
-    const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
-    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
-
-    const written_count = try decodeFrameBlocks(
-        dest,
+    const written_count = decodeFrameBlocks(
+        dest[0..content_size],
         src[consumed_count..],
         &consumed_count,
-        if (hasher_opt) |*hasher| hasher else null,
-    );
+        if (frame_context.hasher_opt) |*hasher| hasher else null,
+    ) catch |err| switch (err) {
+        error.DestTooSmall => return error.BadContentSize,
+        inline else => |e| return e,
+    };
 
-    if (frame_header.descriptor.content_checksum_flag) {
+    if (written_count != content_size) return error.BadContentSize;
+    if (frame_context.has_checksum) {
         if (src.len < consumed_count + 4) return error.EndOfStream;
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
-        if (hasher_opt) |*hasher| {
+        if (frame_context.hasher_opt) |*hasher| {
             if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
         }
     }
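After this change, `decodeZstandardFrame` still requires `dest` to be at least as large as the declared content size, and it now also reports `error.BadContentSize` when the frame decompresses to a different number of bytes than it declared. A rough usage sketch, assuming `decodeZstandardFrame` is in scope; the wrapper function and its error mapping are illustrative only:

    fn decompressOneFrame(dest: []u8, compressed: []const u8) !void {
        _ = decodeZstandardFrame(dest, compressed, true) catch |err| switch (err) {
            // dest is smaller than the size the frame header declares
            error.ContentTooLarge => return error.DestBufferTooSmall,
            // the frame decompressed to more or fewer bytes than it declared
            error.BadContentSize => return error.CorruptFrame,
            else => |e| return e,
        };
    }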
@@ -244,8 +258,14 @@ pub const FrameContext = struct {
     window_size: usize,
     has_checksum: bool,
     block_size_max: usize,
+    content_size: ?usize,
 
-    const Error = error{ DictionaryIdFlagUnsupported, WindowSizeUnknown, WindowTooLarge };
+    const Error = error{
+        DictionaryIdFlagUnsupported,
+        WindowSizeUnknown,
+        WindowTooLarge,
+        ContentSizeTooLarge,
+    };
     /// Validates `frame_header` and returns the associated `FrameContext`.
     ///
     /// Errors returned:
@@ -266,11 +286,18 @@ pub const FrameContext = struct {
             @intCast(usize, window_size_raw);
 
         const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
+
+        const content_size = if (frame_header.content_size) |size|
+            std.math.cast(usize, size) orelse return error.ContentSizeTooLarge
+        else
+            null;
+
         return .{
             .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null,
             .window_size = window_size,
             .has_checksum = frame_header.descriptor.content_checksum_flag,
             .block_size_max = @min(1 << 17, window_size),
+            .content_size = content_size,
         };
     }
 };
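The new `content_size` field is stored as `?usize`, and `std.math.cast` returns `null` when the declared 64-bit size does not fit in `usize`, which is how `error.ContentSizeTooLarge` can arise on 32-bit targets. A small standalone illustration of that cast behaviour (the test itself is not part of this commit):

    const std = @import("std");

    test "declared content size that overflows usize maps to an error" {
        const declared: u64 = std.math.maxInt(u64);
        // std.math.cast yields null on overflow; FrameContext.init turns that
        // null into error.ContentSizeTooLarge via `orelse`.
        const size: ?usize = std.math.cast(usize, declared);
        if (@bitSizeOf(usize) < 64) {
            try std.testing.expect(size == null);
        } else {
            try std.testing.expect(size != null);
        }
    }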
@@ -294,6 +321,8 @@ pub const FrameContext = struct {
 ///   - `error.EndOfStream` if `src` does not contain a complete frame
 ///   - `error.OutOfMemory` if `allocator` cannot allocate enough memory
 ///   - an error in `block.Error` if there are errors decoding a block
+///   - `error.BadContentSize` if the content size declared by the frame does
+///     not equal the size of decompressed data
 pub fn decodeZstandardFrameAlloc(
     allocator: Allocator,
     src: []const u8,
@@ -321,6 +350,7 @@ pub fn decodeZstandardFrameArrayList(
     window_size_max: usize,
 ) (error{OutOfMemory} || FrameContext.Error || FrameError)!usize {
     assert(readInt(u32, src[0..4]) == frame.Zstandard.magic_number);
+    const initial_len = dest.items.len;
     var consumed_count: usize = 4;
 
     var frame_context = context: {
@@ -364,6 +394,12 @@ pub fn decodeZstandardFrameArrayList(
                 hasher.update(written_slice.second);
             }
         }
+        const added_len = dest.items.len - initial_len;
+        if (frame_context.content_size) |size| {
+            if (added_len != size) {
+                return error.BadContentSize;
+            }
+        }
         if (block_header.last_block) break;
     }
 
@@ -384,7 +420,7 @@ fn decodeFrameBlocks(
     src: []const u8,
     consumed_count: *usize,
     hash: ?*std.hash.XxHash64,
-) (error{EndOfStream} || block.Error)!usize {
+) (error{ EndOfStream, DestTooSmall } || block.Error)!usize {
     // These tables take 7680 bytes
     var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
     var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;