Commit a53cf299a6

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-12 12:04:07
std.compress.zstandard: add error condition to ring buffer decoding
Previously `executeSequenceRingBuffer()` would not verify the offset against the number of bytes already decoded, so it would happily copy garbage bytes rather than return an error before the window was filled. To fix this a new `written_count` is added to the decode state that tracks the total number of bytes decoded.
1 parent 5a31fc2
Changed files (1)
lib
std
compress
zstandard
decode
lib/std/compress/zstandard/decode/block.zig
@@ -45,6 +45,7 @@ pub const DecodeState = struct {
     huffman_tree: ?LiteralsSection.HuffmanTree,
 
     literal_written_count: usize,
+    written_count: usize = 0,
 
     fn StateData(comptime max_accuracy_log: comptime_int) type {
         return struct {
@@ -84,6 +85,8 @@ pub const DecodeState = struct {
             .literal_stream_reader = undefined,
             .literal_stream_index = undefined,
             .huffman_tree = null,
+
+            .written_count = 0,
         };
     }
 
@@ -296,6 +299,7 @@ pub const DecodeState = struct {
         // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
         //       to allow repeats
         std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
+        self.written_count += sequence.match_length;
     }
 
     fn executeSequenceRingBuffer(
@@ -303,7 +307,8 @@ pub const DecodeState = struct {
         dest: *RingBuffer,
         sequence: Sequence,
     ) (error{MalformedSequence} || DecodeLiteralsError)!void {
-        if (sequence.offset > dest.data.len) return error.MalformedSequence;
+        if (sequence.offset > @min(dest.data.len, self.written_count + sequence.literal_length))
+            return error.MalformedSequence;
 
         try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
         const copy_start = dest.write_index + dest.data.len - sequence.offset;
@@ -311,6 +316,7 @@ pub const DecodeState = struct {
         // TODO: would std.mem.copy and figuring out dest slice be better/faster?
         for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
         for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
+        self.written_count += sequence.match_length;
     }
 
     const DecodeSequenceError = error{
@@ -444,6 +450,7 @@ pub const DecodeState = struct {
                 const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 std.mem.copy(u8, dest, literal_data);
                 self.literal_written_count += len;
+                self.written_count += len;
             },
             .rle => {
                 var i: usize = 0;
@@ -451,6 +458,7 @@ pub const DecodeState = struct {
                     dest[i] = self.literal_streams.one[0];
                 }
                 self.literal_written_count += len;
+                self.written_count += len;
             },
             .compressed, .treeless => {
                 // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
@@ -497,6 +505,7 @@ pub const DecodeState = struct {
                     }
                 }
                 self.literal_written_count += len;
+                self.written_count += len;
             },
         }
     }
@@ -516,6 +525,7 @@ pub const DecodeState = struct {
                 const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 dest.writeSliceAssumeCapacity(literal_data);
                 self.literal_written_count += len;
+                self.written_count += len;
             },
             .rle => {
                 var i: usize = 0;
@@ -523,6 +533,7 @@ pub const DecodeState = struct {
                     dest.writeAssumeCapacity(self.literal_streams.one[0]);
                 }
                 self.literal_written_count += len;
+                self.written_count += len;
             },
             .compressed, .treeless => {
                 // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
@@ -565,6 +576,7 @@ pub const DecodeState = struct {
                     }
                 }
                 self.literal_written_count += len;
+                self.written_count += len;
             },
         }
     }
@@ -612,6 +624,7 @@ pub fn decodeBlock(
             const data = src[0..block_size];
             std.mem.copy(u8, dest[written_count..], data);
             consumed_count.* += block_size;
+            decode_state.written_count += block_size;
             return block_size;
         },
         .rle => {
@@ -622,6 +635,7 @@ pub fn decodeBlock(
                 dest[write_pos] = src[0];
             }
             consumed_count.* += 1;
+            decode_state.written_count += block_size;
             return block_size;
         },
         .compressed => {
@@ -712,6 +726,7 @@ pub fn decodeBlockRingBuffer(
             const data = src[0..block_size];
             dest.writeSliceAssumeCapacity(data);
             consumed_count.* += block_size;
+            decode_state.written_count += block_size;
             return block_size;
         },
         .rle => {
@@ -721,6 +736,7 @@ pub fn decodeBlockRingBuffer(
                 dest.writeAssumeCapacity(src[0]);
             }
             consumed_count.* += 1;
+            decode_state.written_count += block_size;
             return block_size;
         },
         .compressed => {
@@ -814,6 +830,7 @@ pub fn decodeBlockReader(
             try source.readNoEof(slice.first);
             try source.readNoEof(slice.second);
             dest.write_index = dest.mask2(dest.write_index + block_size);
+            decode_state.written_count += block_size;
         },
         .rle => {
             const byte = try source.readByte();
@@ -821,6 +838,7 @@ pub fn decodeBlockReader(
             while (i < block_size) : (i += 1) {
                 dest.writeAssumeCapacity(byte);
             }
+            decode_state.written_count += block_size;
         },
         .compressed => {
             const literals = try decodeLiteralsSection(block_reader, literals_buffer);