Commit 12aa478db0

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-13 07:19:33
std.compress.zstandard: also check block size when sequence count is 0
1 parent a53cf29
Changed files (2)
lib
std
compress
zstandard
lib/std/compress/zstandard/decode/block.zig
@@ -654,29 +654,32 @@ pub fn decodeBlock(
             bytes_read += fbs.pos;
 
             var bytes_written: usize = 0;
-            if (sequences_header.sequence_count > 0) {
+            {
                 const bit_stream_bytes = src[bytes_read..block_size];
                 var bit_stream: readers.ReverseBitReader = undefined;
                 bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
 
-                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
-
-                var sequence_size_limit = block_size_max;
-                var i: usize = 0;
-                while (i < sequences_header.sequence_count) : (i += 1) {
-                    const write_pos = written_count + bytes_written;
-                    const decompressed_size = decode_state.decodeSequenceSlice(
-                        dest,
-                        write_pos,
-                        &bit_stream,
-                        sequence_size_limit,
-                        i == sequences_header.sequence_count - 1,
-                    ) catch |err| switch (err) {
-                        error.DestTooSmall => return error.DestTooSmall,
-                        else => return error.MalformedCompressedBlock,
-                    };
-                    bytes_written += decompressed_size;
-                    sequence_size_limit -= decompressed_size;
+                if (sequences_header.sequence_count > 0) {
+                    decode_state.readInitialFseState(&bit_stream) catch
+                        return error.MalformedCompressedBlock;
+
+                    var sequence_size_limit = block_size_max;
+                    var i: usize = 0;
+                    while (i < sequences_header.sequence_count) : (i += 1) {
+                        const write_pos = written_count + bytes_written;
+                        const decompressed_size = decode_state.decodeSequenceSlice(
+                            dest,
+                            write_pos,
+                            &bit_stream,
+                            sequence_size_limit,
+                            i == sequences_header.sequence_count - 1,
+                        ) catch |err| switch (err) {
+                            error.DestTooSmall => return error.DestTooSmall,
+                            else => return error.MalformedCompressedBlock,
+                        };
+                        bytes_written += decompressed_size;
+                        sequence_size_limit -= decompressed_size;
+                    }
                 }
 
                 if (!bit_stream.isEmpty()) {
@@ -755,24 +758,27 @@ pub fn decodeBlockRingBuffer(
             bytes_read += fbs.pos;
 
             var bytes_written: usize = 0;
-            if (sequences_header.sequence_count > 0) {
+            {
                 const bit_stream_bytes = src[bytes_read..block_size];
                 var bit_stream: readers.ReverseBitReader = undefined;
                 bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
 
-                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
-
-                var sequence_size_limit = block_size_max;
-                var i: usize = 0;
-                while (i < sequences_header.sequence_count) : (i += 1) {
-                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
-                        dest,
-                        &bit_stream,
-                        sequence_size_limit,
-                        i == sequences_header.sequence_count - 1,
-                    ) catch return error.MalformedCompressedBlock;
-                    bytes_written += decompressed_size;
-                    sequence_size_limit -= decompressed_size;
+                if (sequences_header.sequence_count > 0) {
+                    decode_state.readInitialFseState(&bit_stream) catch
+                        return error.MalformedCompressedBlock;
+
+                    var sequence_size_limit = block_size_max;
+                    var i: usize = 0;
+                    while (i < sequences_header.sequence_count) : (i += 1) {
+                        const decompressed_size = decode_state.decodeSequenceRingBuffer(
+                            dest,
+                            &bit_stream,
+                            sequence_size_limit,
+                            i == sequences_header.sequence_count - 1,
+                        ) catch return error.MalformedCompressedBlock;
+                        bytes_written += decompressed_size;
+                        sequence_size_limit -= decompressed_size;
+                    }
                 }
 
                 if (!bit_stream.isEmpty()) {
@@ -847,28 +853,32 @@ pub fn decodeBlockReader(
             try decode_state.prepare(block_reader, literals, sequences_header);
 
             var bytes_written: usize = 0;
-            if (sequences_header.sequence_count > 0) {
-                if (sequence_buffer.len < block_reader_limited.bytes_left)
-                    return error.SequenceBufferTooSmall;
-
+            {
                 const size = try block_reader.readAll(sequence_buffer);
                 var bit_stream: readers.ReverseBitReader = undefined;
                 try bit_stream.init(sequence_buffer[0..size]);
 
-                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
-
-                var sequence_size_limit = block_size_max;
-                var i: usize = 0;
-                while (i < sequences_header.sequence_count) : (i += 1) {
-                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
-                        dest,
-                        &bit_stream,
-                        sequence_size_limit,
-                        i == sequences_header.sequence_count - 1,
-                    ) catch return error.MalformedCompressedBlock;
-                    sequence_size_limit -= decompressed_size;
-                    bytes_written += decompressed_size;
+                if (sequences_header.sequence_count > 0) {
+                    if (sequence_buffer.len < block_reader_limited.bytes_left)
+                        return error.SequenceBufferTooSmall;
+
+                    decode_state.readInitialFseState(&bit_stream) catch
+                        return error.MalformedCompressedBlock;
+
+                    var sequence_size_limit = block_size_max;
+                    var i: usize = 0;
+                    while (i < sequences_header.sequence_count) : (i += 1) {
+                        const decompressed_size = decode_state.decodeSequenceRingBuffer(
+                            dest,
+                            &bit_stream,
+                            sequence_size_limit,
+                            i == sequences_header.sequence_count - 1,
+                        ) catch return error.MalformedCompressedBlock;
+                        sequence_size_limit -= decompressed_size;
+                        bytes_written += decompressed_size;
+                    }
                 }
+
                 if (!bit_stream.isEmpty()) {
                     return error.MalformedCompressedBlock;
                 }
lib/std/compress/zstandard/readers.zig
@@ -36,8 +36,9 @@ pub const ReverseBitReader = struct {
     pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         self.byte_reader = ReversedByteReader.init(bytes);
         self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
+        if (bytes.len == 0) return;
         var i: usize = 0;
-        while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) : (i += 1) {}
+        while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch unreachable) : (i += 1) {}
         if (i == 8) return error.BitStreamHasNoStartBit;
     }