Commit 373d8ef26e

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-11 18:33:20
std.compress.zstandard: check FSE bitstreams are fully consumed
1 parent 1530e73
Changed files (3)
lib
std
compress
lib/std/compress/zstandard/decode/block.zig
@@ -391,15 +391,21 @@ pub const DecodeState = struct {
         try self.literal_stream_reader.init(bytes);
     }
 
+    fn isLiteralStreamEmpty(self: *DecodeState) bool {
+        switch (self.literal_streams) {
+            .one => return self.literal_stream_reader.isEmpty(),
+            .four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(),
+        }
+    }
+
     const LiteralBitsError = error{
         BitStreamHasNoStartBit,
         UnexpectedEndOfLiteralStream,
     };
     fn readLiteralsBits(
         self: *DecodeState,
-        comptime T: type,
         bit_count_to_read: usize,
-    ) LiteralBitsError!T {
+    ) LiteralBitsError!u16 {
         return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
             if (self.literal_streams == .four and self.literal_stream_index < 3) {
                 try self.nextLiteralMultiStream();
@@ -461,7 +467,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| {
+                        const new_bits = self.readLiteralsBits(bit_count_to_read) catch |err| {
                             return err;
                         };
                         prefix <<= bit_count_to_read;
@@ -533,7 +539,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
+                        const new_bits = try self.readLiteralsBits(bit_count_to_read);
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
@@ -659,13 +665,10 @@ pub fn decodeBlock(
                     sequence_size_limit -= decompressed_size;
                 }
 
-                if (bit_stream.bit_reader.bit_count != 0) {
+                if (!bit_stream.isEmpty()) {
                     return error.MalformedCompressedBlock;
                 }
-
-                bytes_read += bit_stream_bytes.len;
             }
-            if (bytes_read != block_size) return error.MalformedCompressedBlock;
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
@@ -675,7 +678,9 @@ pub fn decodeBlock(
                 bytes_written += len;
             }
 
-            consumed_count.* += bytes_read;
+            if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
+            consumed_count.* += block_size;
             return bytes_written;
         },
         .reserved => return error.ReservedBlock,
@@ -749,13 +754,10 @@ pub fn decodeBlockRingBuffer(
                     sequence_size_limit -= decompressed_size;
                 }
 
-                if (bit_stream.bit_reader.bit_count != 0) {
+                if (!bit_stream.isEmpty()) {
                     return error.MalformedCompressedBlock;
                 }
-
-                bytes_read += bit_stream_bytes.len;
             }
-            if (bytes_read != block_size) return error.MalformedCompressedBlock;
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
@@ -764,7 +766,9 @@ pub fn decodeBlockRingBuffer(
                 bytes_written += len;
             }
 
-            consumed_count.* += bytes_read;
+            if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
+            consumed_count.* += block_size;
             if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
             return bytes_written;
         },
@@ -837,7 +841,7 @@ pub fn decodeBlockReader(
                     sequence_size_limit -= decompressed_size;
                     bytes_written += decompressed_size;
                 }
-                if (bit_stream.bit_reader.bit_count != 0) {
+                if (!bit_stream.isEmpty()) {
                     return error.MalformedCompressedBlock;
                 }
             }
@@ -849,6 +853,8 @@ pub fn decodeBlockReader(
                 bytes_written += len;
             }
 
+            if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
             if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
             if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock;
             decode_state.literal_written_count = 0;
lib/std/compress/zstandard/decode/huffman.zig
@@ -86,6 +86,10 @@ fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entr
         odd_state = odd_data.baseline + odd_bits;
     } else return error.MalformedHuffmanTree;
 
+    if (!huff_bits.isEmpty()) {
+        return error.MalformedHuffmanTree;
+    }
+
     return i + 1; // stream contains all but the last symbol
 }
 
lib/std/compress/zstandard/readers.zig
@@ -36,7 +36,9 @@ pub const ReverseBitReader = struct {
     pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         self.byte_reader = ReversedByteReader.init(bytes);
         self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
-        while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
+        var i: usize = 0;
+        while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) : (i += 1) {}
+        if (i == 8) return error.BitStreamHasNoStartBit;
     }
 
     pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
@@ -50,6 +52,10 @@ pub const ReverseBitReader = struct {
     pub fn alignToByte(self: *@This()) void {
         self.bit_reader.alignToByte();
     }
+
+    pub fn isEmpty(self: ReverseBitReader) bool {
+        return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
+    }
 };
 
 pub fn BitReader(comptime Reader: type) type {