Commit 6b85373875

dweiller <4678790+dweiller@users.noreply.github.com>
2023-01-24 03:07:58
std.compress.zstandard: validate sequence lengths
1 parent 082acd7
Changed files (1)
lib/std/compress/zstandard/decompress.zig
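The diff replaces a debug assert and two TODOs with explicit validation, so malformed input is rejected with error.MalformedSequence instead of relying on an assert or corrupting the output. Two conditions are checked per sequence: the match offset may not reach back past the data available so far (otherwise the copy_start subtraction would underflow), and the sequence's total output (literal_length + match_length) may not exceed the space the block is still allowed to fill. A rough standalone sketch of those checks follows; it is not the committed code, and validateSequence, write_pos, and remaining_block_space are illustrative names only:

const Sequence = struct {
    literal_length: u32,
    match_length: u32,
    offset: u32,
};

// Hypothetical helper mirroring the slice-path checks this commit adds
// (the ring-buffer path compares offset against the window length instead).
fn validateSequence(
    sequence: Sequence,
    write_pos: usize,
    remaining_block_space: usize,
) error{MalformedSequence}!void {
    // A match may only copy from data that already exists in the window;
    // otherwise copy_start = write_pos + literal_length - offset underflows.
    if (sequence.offset > write_pos + sequence.literal_length)
        return error.MalformedSequence;
    // Widen before adding, then compare against the block's remaining budget.
    const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
    if (sequence_length > remaining_block_space)
        return error.MalformedSequence;
}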
@@ -271,10 +271,9 @@ pub const DecodeState = struct {
         literals: LiteralsSection,
         sequence: Sequence,
     ) !void {
-        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
+        if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;
 
-        // TODO: should we validate offset against max_window_size?
-        assert(sequence.offset <= write_pos + sequence.literal_length);
+        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
         const copy_start = write_pos + sequence.literal_length - sequence.offset;
         const copy_end = copy_start + sequence.match_length;
         // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
@@ -288,8 +287,9 @@ pub const DecodeState = struct {
         literals: LiteralsSection,
         sequence: Sequence,
     ) !void {
+        if (sequence.offset > dest.data.len) return error.MalformedSequence;
+
         try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
-        // TODO: check that ring buffer window is full enough for match copies
         const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
         // TODO: would std.mem.copy and figuring out dest slice be better/faster?
         for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
@@ -302,9 +302,13 @@ pub const DecodeState = struct {
         write_pos: usize,
         literals: LiteralsSection,
         bit_reader: anytype,
+        sequence_size_limit: usize,
         last_sequence: bool,
     ) !usize {
         const sequence = try self.nextSequence(bit_reader);
+        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
+        if (sequence_length > sequence_size_limit) return error.MalformedSequence;
+
         try self.executeSequenceSlice(dest, write_pos, literals, sequence);
         log.debug("sequence decompressed into '{x}'", .{
             std.fmt.fmtSliceHexUpper(dest[write_pos .. write_pos + sequence.literal_length + sequence.match_length]),
@@ -314,7 +318,7 @@ pub const DecodeState = struct {
             try self.updateState(.match, bit_reader);
             try self.updateState(.offset, bit_reader);
         }
-        return sequence.match_length + sequence.literal_length;
+        return sequence_length;
     }
 
     pub fn decodeSequenceRingBuffer(
@@ -322,12 +326,15 @@ pub const DecodeState = struct {
         dest: *RingBuffer,
         literals: LiteralsSection,
         bit_reader: anytype,
+        sequence_size_limit: usize,
         last_sequence: bool,
     ) !usize {
         const sequence = try self.nextSequence(bit_reader);
+        const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
+        if (sequence_length > sequence_size_limit) return error.MalformedSequence;
+
         try self.executeSequenceRingBuffer(dest, literals, sequence);
         if (std.options.log_level == .debug) {
-            const sequence_length = sequence.literal_length + sequence.match_length;
             const written_slice = dest.sliceLast(sequence_length);
             log.debug("sequence decompressed into '{x}{x}'", .{
                 std.fmt.fmtSliceHexUpper(written_slice.first),
@@ -339,7 +346,7 @@ pub const DecodeState = struct {
             try self.updateState(.match, bit_reader);
             try self.updateState(.offset, bit_reader);
         }
-        return sequence.match_length + sequence.literal_length;
+        return sequence_length;
     }
 
     fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void {
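The block-level hunks below thread a per-block output budget through the sequence loop: sequence_size_limit starts at the block's maximum decompressed size and shrinks by the size of every decoded sequence, and each call to decodeSequenceSlice / decodeSequenceRingBuffer rejects a sequence that would exceed what is left. A minimal sketch of that budgeting pattern, with a stand-in stub instead of the real decoder calls:

// Illustrative sketch of the budgeting added to decodeBlock and
// decodeBlockRingBuffer below; decodeOneSequence is a placeholder stub,
// not part of the real API.
fn decodeOneSequence(sequence_size_limit: usize) !usize {
    _ = sequence_size_limit; // the real decoders validate against this budget
    return 0;
}

fn decodeAllSequences(sequence_count: usize, block_size_max: usize) !usize {
    var bytes_written: usize = 0;
    // The budget starts at the block's maximum decompressed size...
    var sequence_size_limit: usize = block_size_max;
    var i: usize = 0;
    while (i < sequence_count) : (i += 1) {
        const decompressed_size = try decodeOneSequence(sequence_size_limit);
        bytes_written += decompressed_size;
        // ...and shrinks after every sequence; the per-sequence check keeps
        // decompressed_size <= sequence_size_limit, so this cannot underflow.
        sequence_size_limit -= decompressed_size;
    }
    return bytes_written;
}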
@@ -717,9 +724,9 @@ pub fn decodeBlock(
     consumed_count: *usize,
     written_count: usize,
 ) !usize {
-    const block_maximum_size = 1 << 17; // 128KiB
+    const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
     const block_size = block_header.block_size;
-    if (block_maximum_size < block_size) return error.BlockSizeOverMaximum;
+    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     switch (block_header.block_type) {
         .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count),
@@ -739,17 +746,21 @@ pub fn decodeBlock(
 
                 try decode_state.readInitialFseState(&bit_stream);
 
+                var sequence_size_limit = block_size_max;
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
                     log.debug("decoding sequence {d}", .{i});
+                    const write_pos = written_count + bytes_written;
                     const decompressed_size = try decode_state.decodeSequenceSlice(
                         dest,
-                        written_count + bytes_written,
+                        write_pos,
                         literals,
                         &bit_stream,
+                        sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
                     );
                     bytes_written += decompressed_size;
+                    sequence_size_limit -= decompressed_size;
                 }
 
                 bytes_read += bit_stream_bytes.len;
@@ -781,10 +792,10 @@ pub fn decodeBlockRingBuffer(
     block_header: frame.ZStandard.Block.Header,
     decode_state: *DecodeState,
     consumed_count: *usize,
-    block_size_maximum: usize,
+    block_size_max: usize,
 ) !usize {
     const block_size = block_header.block_size;
-    if (block_size_maximum < block_size) return error.BlockSizeOverMaximum;
+    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     switch (block_header.block_type) {
         .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
@@ -804,6 +815,7 @@ pub fn decodeBlockRingBuffer(
 
                 try decode_state.readInitialFseState(&bit_stream);
 
+                var sequence_size_limit = block_size_max;
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
                     log.debug("decoding sequence {d}", .{i});
@@ -811,9 +823,11 @@ pub fn decodeBlockRingBuffer(
                         dest,
                         literals,
                         &bit_stream,
+                        sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
                     );
                     bytes_written += decompressed_size;
+                    sequence_size_limit -= decompressed_size;
                 }
 
                 bytes_read += bit_stream_bytes.len;
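A note on why the assert became an error: std.debug.assert only checks in Debug and ReleaseSafe builds, and a failed assert in ReleaseFast or ReleaseSmall is undefined behavior (the same concern the retained TODO about enabling safety points at). Returning error.MalformedSequence instead means corrupt or hostile input is reported to the caller in every build mode rather than potentially turning into out-of-bounds reads during the match copy.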