Commit 774e2f5a5c

dweiller <4678790+dweiller@users.noreply.github.com>
2023-01-24 04:30:32
std.compress.zstandard: add input length safety checks
1 parent 31d1cae
Changed files (1)
lib/std/compress/zstandard/decompress.zig
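
The recurring pattern in this commit is to validate `src.len` before taking a slice, returning a `Malformed*` error instead of relying on the slice bounds checks that ReleaseFast and ReleaseSmall builds omit. A minimal sketch of the pattern, with a hypothetical `decodeExample` name that is not part of the commit (the caller is assumed to guarantee `dest` has capacity, as the real decoders do via `block_size_max`):

    const std = @import("std");

    fn decodeExample(dest: []u8, src: []const u8, block_size: u21) !usize {
        // Reject truncated input up front; src[0..block_size] on a
        // too-short slice is undefined behavior in ReleaseFast.
        if (src.len < block_size) return error.MalformedBlockSize;
        std.mem.copy(u8, dest, src[0..block_size]);
        return block_size;
    }

Changing a return type from `usize` to `!usize` turns the decoder into an error-union-returning function, which is why the signatures in the hunks below each gain a `!`.
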
@@ -680,7 +680,8 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
     return written_count;
 }
 
-fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+    if (src.len < block_size) return error.MalformedBlockSize;
     log.debug("writing raw block - size {d}", .{block_size});
     const data = src[0..block_size];
     std.mem.copy(u8, dest, data);
@@ -688,7 +689,8 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
-fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+    if (src.len < block_size) return error.MalformedBlockSize;
     log.debug("writing raw block - size {d}", .{block_size});
     const data = src[0..block_size];
     dest.writeSliceAssumeCapacity(data);
@@ -696,7 +698,8 @@ fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21,
     return block_size;
 }
 
-fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+    if (src.len < 1) return error.MalformedRleBlock;
     log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
     var write_pos: usize = 0;
     while (write_pos < block_size) : (write_pos += 1) {
@@ -706,7 +709,8 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
-fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+    if (src.len < 1) return error.MalformedRleBlock;
     log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
     var write_pos: usize = 0;
     while (write_pos < block_size) : (write_pos += 1) {
@@ -727,11 +731,11 @@ pub fn decodeBlock(
     const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
-    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     switch (block_header.block_type) {
         .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count),
         .rle => return decodeRleBlock(dest[written_count..], src, block_size, consumed_count),
         .compressed => {
+            if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
             const literals = try decodeLiteralsSection(src, &bytes_read);
             const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
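
Note that the `.raw` and `.rle` arms above need no `try` even though their callees are now fallible: in Zig, a function with an inferred error set may return the result of another fallible call directly, and the error unions coerce. A minimal sketch with hypothetical names:

    fn inner(fail: bool) !u32 {
        if (fail) return error.Oops;
        return 42;
    }

    fn outer(fail: bool) !u32 {
        // Returning the call directly forwards the error union unchanged;
        // `try` is only needed to unwrap the payload for local use.
        return inner(fail);
    }
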
@@ -796,11 +800,11 @@ pub fn decodeBlockRingBuffer(
 ) !usize {
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
-    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     switch (block_header.block_type) {
         .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
         .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
         .compressed => {
+            if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
             const literals = try decodeLiteralsSection(src, &bytes_read);
             const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
@@ -957,11 +961,11 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
 }
 
 pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
-    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     var bytes_read: usize = 0;
-    const header = decodeLiteralsHeader(src, &bytes_read);
+    const header = try decodeLiteralsHeader(src, &bytes_read);
     switch (header.block_type) {
         .raw => {
+            if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection;
             const stream = src[bytes_read .. bytes_read + header.regenerated_size];
             consumed_count.* += header.regenerated_size + bytes_read;
             return LiteralsSection{
@@ -971,6 +975,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsS
             };
         },
         .rle => {
+            if (src.len < bytes_read + 1) return error.MalformedLiteralsSection;
             const stream = src[bytes_read .. bytes_read + 1];
             consumed_count.* += 1 + bytes_read;
             return LiteralsSection{
@@ -990,18 +995,19 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsS
             log.debug("huffman tree size = {}, total streams size = {}", .{ huffman_tree_size, total_streams_size });
             if (huffman_tree) |tree| dumpHuffmanTree(tree);
 
+            if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection;
+            const stream_data = src[bytes_read .. bytes_read + total_streams_size];
+
             if (header.size_format == 0) {
-                const stream = src[bytes_read .. bytes_read + total_streams_size];
-                bytes_read += total_streams_size;
-                consumed_count.* += bytes_read;
+                consumed_count.* += total_streams_size + bytes_read;
                 return LiteralsSection{
                     .header = header,
                     .huffman_tree = huffman_tree,
-                    .streams = .{ .one = stream },
+                    .streams = .{ .one = stream_data },
                 };
             }
 
-            const stream_data = src[bytes_read .. bytes_read + total_streams_size];
+            if (stream_data.len < 6) return error.MalformedLiteralsSection;
 
             log.debug("jump table: {}", .{std.fmt.fmtSliceHexUpper(stream_data[0..6])});
             const stream_1_length = @as(usize, readInt(u16, stream_data[0..2]));
@@ -1014,6 +1020,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsS
             const stream_3_start = stream_2_start + stream_2_length;
             const stream_4_start = stream_3_start + stream_3_length;
 
+            if (stream_data.len < stream_4_start + stream_4_length) return error.MalformedLiteralsSection;
             consumed_count.* += total_streams_size + bytes_read;
 
             return LiteralsSection{
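
For context on the new `stream_4_start + stream_4_length` check: per RFC 8878, a four-stream literals section begins with a 6-byte jump table of three little-endian u16 sizes for streams 1-3, and stream 4 takes whatever remains of the total streams size. A sketch of the arithmetic with invented sizes (not taken from the commit):

    const std = @import("std");

    test "four-stream jump table arithmetic (hypothetical sizes)" {
        const total_streams_size: usize = 120;
        const stream_1_length: usize = 30; // readInt(u16, stream_data[0..2])
        const stream_2_length: usize = 30;
        const stream_3_length: usize = 30;
        // Streams begin immediately after the 6-byte jump table.
        const stream_4_start = 6 + stream_1_length + stream_2_length + stream_3_length; // 96
        // RFC 8878: stream 4 is the remainder of the streams data.
        const stream_4_length = total_streams_size - stream_4_start; // 24
        // Exactly fills stream_data; anything larger fails the new check.
        try std.testing.expectEqual(total_streams_size, stream_4_start + stream_4_length);
    }
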
@@ -1033,13 +1040,15 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsS
 fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree {
     var bytes_read: usize = 0;
     bytes_read += 1;
+    if (src.len == 0) return error.MalformedHuffmanTree;
     const header = src[0];
     var symbol_count: usize = undefined;
     var weights: [256]u4 = undefined;
     var max_number_of_bits: u4 = undefined;
     if (header < 128) {
-        // FSE compressed weigths
+        // FSE compressed weights
         const compressed_size = header;
+        if (src.len < 1 + compressed_size) return error.MalformedHuffmanTree;
         var stream = std.io.fixedBufferStream(src[1 .. compressed_size + 1]);
         var counting_reader = std.io.countingReader(stream.reader());
         var bit_reader = bitReader(counting_reader.reader());
@@ -1185,8 +1194,8 @@ fn lessThanByWeight(
     return weights[lhs.symbol] < weights[rhs.symbol];
 }
 
-pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSection.Header {
-    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
+pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSection.Header {
+    if (src.len == 0) return error.MalformedLiteralsSection;
     const start = consumed_count.*;
     const byte0 = src[0];
     const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
@@ -1201,14 +1210,16 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSec
                     consumed_count.* += 1;
                 },
                 1 => {
+                    if (src.len < 2) return error.MalformedLiteralsHeader;
                     regenerated_size = (byte0 >> 4) +
-                        (@as(u20, src[consumed_count.* + 1]) << 4);
+                        (@as(u20, src[1]) << 4);
                     consumed_count.* += 2;
                 },
                 3 => {
+                    if (src.len < 3) return error.MalformedLiteralsHeader;
                     regenerated_size = (byte0 >> 4) +
-                        (@as(u20, src[consumed_count.* + 1]) << 4) +
-                        (@as(u20, src[consumed_count.* + 2]) << 12);
+                        (@as(u20, src[1]) << 4) +
+                        (@as(u20, src[2]) << 12);
                     consumed_count.* += 3;
                 },
             }
@@ -1218,17 +1229,20 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSec
             const byte2 = src[2];
             switch (size_format) {
                 0, 1 => {
+                    if (src.len < 3) return error.MalformedLiteralsHeader;
                     regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4);
                     compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2);
                     consumed_count.* += 3;
                 },
                 2 => {
+                    if (src.len < 4) return error.MalformedLiteralsHeader;
                     const byte3 = src[3];
                     regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12);
                     compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6);
                     consumed_count.* += 4;
                 },
                 3 => {
+                    if (src.len < 5) return error.MalformedLiteralsHeader;
                     const byte3 = src[3];
                     const byte4 = src[4];
                     regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12);
@@ -1257,6 +1271,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSec
 }
 
 pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
+    if (src.len == 0) return error.MalformedSequencesSection;
     var sequence_count: u24 = undefined;
 
     var bytes_read: usize = 0;
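
The hunk below hardens the variable-length sequence-count encoding: a leading byte below 128 is the count itself, a leading byte in 128-254 extends the count to two bytes, and a leading 255 extends it to three. As a worked example with invented bytes, byte0 = 0x85 with src[1] = 0x10 decodes as ((0x85 - 128) << 8) + 0x10 = 1296, so the decoder must confirm the second byte is actually present before reading it.
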
@@ -1275,13 +1290,16 @@ pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences
         sequence_count = byte0;
         bytes_read += 1;
     } else if (byte0 < 255) {
+        if (src.len < 2) return error.MalformedSequencesSection;
         sequence_count = (@as(u24, (byte0 - 128)) << 8) + src[1];
         bytes_read += 2;
     } else {
+        if (src.len < 3) return error.MalformedSequencesSection;
         sequence_count = src[1] + (@as(u24, src[2]) << 8) + 0x7F00;
         bytes_read += 3;
     }
 
+    if (src.len < bytes_read + 1) return error.MalformedSequencesSection;
     const compression_modes = src[bytes_read];
     bytes_read += 1;