Commit 373d8ef26e
Changed files (3)
lib
std
compress
zstandard
lib/std/compress/zstandard/decode/block.zig
@@ -391,15 +391,21 @@ pub const DecodeState = struct {
try self.literal_stream_reader.init(bytes);
}
+ fn isLiteralStreamEmpty(self: *DecodeState) bool {
+ switch (self.literal_streams) {
+ .one => return self.literal_stream_reader.isEmpty(),
+ .four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(),
+ }
+ }
+
const LiteralBitsError = error{
BitStreamHasNoStartBit,
UnexpectedEndOfLiteralStream,
};
fn readLiteralsBits(
self: *DecodeState,
- comptime T: type,
bit_count_to_read: usize,
- ) LiteralBitsError!T {
+ ) LiteralBitsError!u16 {
return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
if (self.literal_streams == .four and self.literal_stream_index < 3) {
try self.nextLiteralMultiStream();
@@ -461,7 +467,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) {
var prefix: u16 = 0;
while (true) {
- const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| {
+ const new_bits = self.readLiteralsBits(bit_count_to_read) catch |err| {
return err;
};
prefix <<= bit_count_to_read;
@@ -533,7 +539,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) {
var prefix: u16 = 0;
while (true) {
- const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
+ const new_bits = try self.readLiteralsBits(bit_count_to_read);
prefix <<= bit_count_to_read;
prefix |= new_bits;
bits_read += bit_count_to_read;
@@ -659,13 +665,10 @@ pub fn decodeBlock(
sequence_size_limit -= decompressed_size;
}
- if (bit_stream.bit_reader.bit_count != 0) {
+ if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock;
}
-
- bytes_read += bit_stream_bytes.len;
}
- if (bytes_read != block_size) return error.MalformedCompressedBlock;
if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count;
@@ -675,7 +678,9 @@ pub fn decodeBlock(
bytes_written += len;
}
- consumed_count.* += bytes_read;
+ if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
+ consumed_count.* += block_size;
return bytes_written;
},
.reserved => return error.ReservedBlock,
@@ -749,13 +754,10 @@ pub fn decodeBlockRingBuffer(
sequence_size_limit -= decompressed_size;
}
- if (bit_stream.bit_reader.bit_count != 0) {
+ if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock;
}
-
- bytes_read += bit_stream_bytes.len;
}
- if (bytes_read != block_size) return error.MalformedCompressedBlock;
if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count;
@@ -764,7 +766,9 @@ pub fn decodeBlockRingBuffer(
bytes_written += len;
}
- consumed_count.* += bytes_read;
+ if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
+ consumed_count.* += block_size;
if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
return bytes_written;
},
@@ -837,7 +841,7 @@ pub fn decodeBlockReader(
sequence_size_limit -= decompressed_size;
bytes_written += decompressed_size;
}
- if (bit_stream.bit_reader.bit_count != 0) {
+ if (!bit_stream.isEmpty()) {
return error.MalformedCompressedBlock;
}
}
@@ -849,6 +853,8 @@ pub fn decodeBlockReader(
bytes_written += len;
}
+ if (!decode_state.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
+
if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
if (block_reader_limited.bytes_left != 0) return error.MalformedCompressedBlock;
decode_state.literal_written_count = 0;
lib/std/compress/zstandard/decode/huffman.zig
@@ -86,6 +86,10 @@ fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entr
odd_state = odd_data.baseline + odd_bits;
} else return error.MalformedHuffmanTree;
+ if (!huff_bits.isEmpty()) {
+ return error.MalformedHuffmanTree;
+ }
+
return i + 1; // stream contains all but the last symbol
}
lib/std/compress/zstandard/readers.zig
@@ -36,7 +36,9 @@ pub const ReverseBitReader = struct {
pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
self.byte_reader = ReversedByteReader.init(bytes);
self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
- while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
+ var i: usize = 0;
+ while (i < 8 and 0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) : (i += 1) {}
+ if (i == 8) return error.BitStreamHasNoStartBit;
}
pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
@@ -50,6 +52,10 @@ pub const ReverseBitReader = struct {
pub fn alignToByte(self: *@This()) void {
self.bit_reader.alignToByte();
}
+
+ pub fn isEmpty(self: ReverseBitReader) bool {
+ return self.byte_reader.remaining_bytes == 0 and self.bit_reader.bit_count == 0;
+ }
};
pub fn BitReader(comptime Reader: type) type {