Commit 596a97fb55

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-03 02:56:51
std.compress.zstandard: fix crashes
1 parent a651704
Changed files (3)
lib
std
compress
lib/std/compress/zstandard/decode/block.zig
@@ -148,8 +148,8 @@ pub const DecodeState = struct {
     }
 
     fn updateRepeatOffset(self: *DecodeState, offset: u32) void {
-        std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1]);
-        std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
+        self.repeat_offsets[2] = self.repeat_offsets[1];
+        self.repeat_offsets[1] = self.repeat_offsets[0];
         self.repeat_offsets[0] = offset;
     }
 
@@ -238,18 +238,22 @@ pub const DecodeState = struct {
     fn nextSequence(
         self: *DecodeState,
         bit_reader: *readers.ReverseBitReader,
-    ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence {
+    ) error{ InvalidBitStream, EndOfStream }!Sequence {
         const raw_code = self.getCode(.offset);
         const offset_code = std.math.cast(u5, raw_code) orelse {
-            return error.OffsetCodeTooLarge;
+            return error.InvalidBitStream;
         };
         const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);
 
         const match_code = self.getCode(.match);
+        if (match_code >= types.compressed_block.match_length_code_table.len)
+            return error.InvalidBitStream;
         const match = types.compressed_block.match_length_code_table[match_code];
         const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);
 
         const literal_code = self.getCode(.literal);
+        if (literal_code >= types.compressed_block.literals_length_code_table.len)
+            return error.InvalidBitStream;
         const literal = types.compressed_block.literals_length_code_table[literal_code];
         const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);
 
@@ -269,6 +273,8 @@ pub const DecodeState = struct {
             break :offset self.useRepeatOffset(offset_value - 1);
         };
 
+        if (offset == 0) return error.InvalidBitStream;
+
         return .{
             .literal_length = literal_length,
             .match_length = match_length,
@@ -308,7 +314,7 @@ pub const DecodeState = struct {
     }
 
     const DecodeSequenceError = error{
-        OffsetCodeTooLarge,
+        InvalidBitStream,
         EndOfStream,
         MalformedSequence,
         MalformedFseBits,
@@ -326,7 +332,7 @@ pub const DecodeState = struct {
     ///   - `error.UnexpectedEndOfLiteralStream` if the decoder state's literal
     ///     streams do not contain enough literals for the sequence (this may
     ///     mean the literal stream or the sequence is malformed).
-    ///   - `error.OffsetCodeTooLarge` if an invalid offset code is found
+    ///   - `error.InvalidBitStream` if the FSE sequence bitstream is malformed
     ///   - `error.EndOfStream` if `bit_reader` does not contain enough bits
     pub fn decodeSequenceSlice(
         self: *DecodeState,
@@ -608,9 +614,9 @@ pub fn decodeBlock(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
+            const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
                 return error.MalformedCompressedBlock;
-            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
+            var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
             const fbs_reader = fbs.reader();
             const sequences_header = decodeSequencesHeader(fbs_reader) catch
                 return error.MalformedCompressedBlock;
@@ -695,9 +701,9 @@ pub fn decodeBlockRingBuffer(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
+            const literals = decodeLiteralsSectionSlice(src[0..block_size], &bytes_read) catch
                 return error.MalformedCompressedBlock;
-            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
+            var fbs = std.io.fixedBufferStream(src[bytes_read..block_size]);
             const fbs_reader = fbs.reader();
             const sequences_header = decodeSequencesHeader(fbs_reader) catch
                 return error.MalformedCompressedBlock;
@@ -894,7 +900,8 @@ pub fn decodeLiteralsSectionSlice(
             else
                 null;
             const huffman_tree_size = bytes_read - huffman_tree_start;
-            const total_streams_size = @as(usize, header.compressed_size.?) - huffman_tree_size;
+            const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
+                return error.MalformedLiteralsSection;
 
             if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection;
             const stream_data = src[bytes_read .. bytes_read + total_streams_size];
@@ -940,8 +947,9 @@ pub fn decodeLiteralsSection(
                 try huffman.decodeHuffmanTree(counting_reader.reader(), buffer)
             else
                 null;
-            const huffman_tree_size = counting_reader.bytes_read;
-            const total_streams_size = @as(usize, header.compressed_size.?) - @intCast(usize, huffman_tree_size);
+            const huffman_tree_size = @intCast(usize, counting_reader.bytes_read);
+            const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
+                return error.MalformedLiteralsSection;
 
             if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
             try source.readNoEof(buffer[0..total_streams_size]);
lib/std/compress/zstandard/decode/huffman.zig
@@ -146,13 +146,14 @@ fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.P
     return prefixed_symbol_count;
 }
 
-fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree {
+fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffmanTree}!LiteralsSection.HuffmanTree {
     var weight_power_sum: u16 = 0;
     for (weights[0 .. symbol_count - 1]) |value| {
         if (value > 0) {
             weight_power_sum += @as(u16, 1) << (value - 1);
         }
     }
+    if (weight_power_sum >= 1 << 11) return error.MalformedHuffmanTree;
 
     // advance to next power of two (even if weight_power_sum is a power of 2)
     const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
lib/std/compress/zstandard/decompress.zig
@@ -195,6 +195,7 @@ pub fn decodeZstandardFrame(
     );
 
     if (frame_header.descriptor.content_checksum_flag) {
+        if (src.len < consumed_count + 4) return error.EndOfStream;
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
         if (hasher_opt) |*hasher| {
@@ -302,17 +303,20 @@ pub fn decodeZstandardFrameAlloc(
             &consumed_count,
             frame_context.block_size_max,
         );
-        const written_slice = ring_buffer.sliceLast(written_size);
-        try result.appendSlice(written_slice.first);
-        try result.appendSlice(written_slice.second);
-        if (frame_context.hasher_opt) |*hasher| {
-            hasher.update(written_slice.first);
-            hasher.update(written_slice.second);
+        if (written_size > 0) {
+            const written_slice = ring_buffer.sliceLast(written_size);
+            try result.appendSlice(written_slice.first);
+            try result.appendSlice(written_slice.second);
+            if (frame_context.hasher_opt) |*hasher| {
+                hasher.update(written_slice.first);
+                hasher.update(written_slice.second);
+            }
         }
         if (block_header.last_block) break;
     }
 
     if (frame_context.has_checksum) {
+        if (src.len < consumed_count + 4) return error.EndOfStream;
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
         if (frame_context.hasher_opt) |*hasher| {