Commit 3bfba36548
Changed files (2)
lib/std/compress/zstandard/decompress.zig
@@ -22,7 +22,7 @@ fn isSkippableMagic(magic: u32) bool {
/// if the frame is skippable, `null` for Zstandard frames that do not
/// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet`
/// errors if the respective bits of the frame descriptor are set.
-pub fn getFrameDecompressedSize(src: []const u8) !?u64 {
+pub fn getFrameDecompressedSize(src: []const u8) (InvalidBit || error{BadMagic})!?u64 {
switch (try frameType(src)) {
.zstandard => {
const header = try decodeZStandardHeader(src[4..], null);
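
A rough call-site sketch of the now-explicit error set (not part of this diff; the `@import` path is hypothetical), showing that the three possible errors can be switched on exhaustively:

const std = @import("std");
const decompress = @import("decompress.zig"); // hypothetical import path

fn reportFrameSize(src: []const u8) void {
    const maybe_size = decompress.getFrameDecompressedSize(src) catch |err| switch (err) {
        error.BadMagic => {
            std.debug.print("not a Zstandard or skippable frame\n", .{});
            return;
        },
        error.UnusedBitSet, error.ReservedBitSet => {
            std.debug.print("malformed frame descriptor\n", .{});
            return;
        },
    };
    if (maybe_size) |size| {
        std.debug.print("declared content size: {d} bytes\n", .{size});
    } else {
        std.debug.print("content size not declared\n", .{});
    }
}
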
@@ -52,7 +52,11 @@ const ReadWriteCount = struct {
/// Decodes the frame at the start of `src` into `dest`. Returns the number of
/// bytes read from `src` and written to `dest`.
-pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
+pub fn decodeFrame(
+ dest: []u8,
+ src: []const u8,
+ verify_checksum: bool,
+) (error{ UnknownContentSizeUnsupported, ContentTooLarge, BadMagic } || FrameError)!ReadWriteCount {
return switch (try frameType(src)) {
.zstandard => decodeZStandardFrame(dest, src, verify_checksum),
.skippable => ReadWriteCount{
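
A hedged sketch of walking a buffer of concatenated frames with the new `decodeFrame` signature; the `ReadWriteCount` field names used below (`read_count`, `write_count`) are an assumption, since this hunk does not show them:

const std = @import("std");
const decompress = @import("decompress.zig"); // hypothetical import path

fn decodeAllFrames(dest: []u8, src: []const u8) !usize {
    var read_pos: usize = 0;
    var write_pos: usize = 0;
    while (read_pos < src.len) {
        // Checksums are verified; any frame error aborts the whole decode.
        const counts = try decompress.decodeFrame(dest[write_pos..], src[read_pos..], true);
        read_pos += counts.read_count; // assumed field name
        write_pos += counts.write_count; // assumed field name
    }
    return write_pos;
}
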
@@ -100,7 +104,7 @@ pub const DecodeState = struct {
src: []const u8,
literals: LiteralsSection,
sequences_header: SequencesSection.Header,
- ) !usize {
+ ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize {
if (literals.huffman_tree) |tree| {
self.huffman_tree = tree;
} else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
@@ -145,7 +149,7 @@ pub const DecodeState = struct {
/// Read initial FSE states for sequence decoding. Returns `error.EndOfStream`
/// if `bit_reader` does not contain enough bits.
- pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void {
+ pub fn readInitialFseState(self: *DecodeState, bit_reader: *ReverseBitReader) error{EndOfStream}!void {
self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
@@ -169,7 +173,11 @@ pub const DecodeState = struct {
const DataType = enum { offset, match, literal };
- fn updateState(self: *DecodeState, comptime choice: DataType, bit_reader: anytype) !void {
+ fn updateState(
+ self: *DecodeState,
+ comptime choice: DataType,
+ bit_reader: *ReverseBitReader,
+ ) error{ MalformedFseBits, EndOfStream }!void {
switch (@field(self, @tagName(choice)).table) {
.rle => {},
.fse => |table| {
@@ -185,17 +193,27 @@ pub const DecodeState = struct {
}
}
+ const FseTableError = error{
+ MalformedFseTable,
+ MalformedAccuracyLog,
+ RepeatModeFirst,
+ EndOfStream,
+ };
+
fn updateFseTable(
self: *DecodeState,
src: []const u8,
comptime choice: DataType,
mode: SequencesSection.Header.Mode,
- ) !usize {
+ ) FseTableError!usize {
const field_name = @tagName(choice);
switch (mode) {
.predefined => {
- @field(self, field_name).accuracy_log = @field(types.compressed_block.default_accuracy_log, field_name);
- @field(self, field_name).table = @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
+ @field(self, field_name).accuracy_log =
+ @field(types.compressed_block.default_accuracy_log, field_name);
+
+ @field(self, field_name).table =
+ @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
return 0;
},
.rle => {
@@ -214,9 +232,11 @@ pub const DecodeState = struct {
@field(types.compressed_block.table_accuracy_log_max, field_name),
@field(self, field_name ++ "_fse_buffer"),
);
- @field(self, field_name).table = .{ .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size] };
+ @field(self, field_name).table = .{
+ .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
+ };
@field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
- return std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedFseTable;
+ return std.math.cast(usize, counting_reader.bytes_read) orelse error.MalformedFseTable;
},
.repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0,
}
@@ -228,7 +248,10 @@ pub const DecodeState = struct {
offset: u32,
};
- fn nextSequence(self: *DecodeState, bit_reader: anytype) !Sequence {
+ fn nextSequence(
+ self: *DecodeState,
+ bit_reader: *ReverseBitReader,
+ ) error{ OffsetCodeTooLarge, EndOfStream }!Sequence {
const raw_code = self.getCode(.offset);
const offset_code = std.math.cast(u5, raw_code) orelse {
return error.OffsetCodeTooLarge;
@@ -272,7 +295,7 @@ pub const DecodeState = struct {
write_pos: usize,
literals: LiteralsSection,
sequence: Sequence,
- ) !void {
+ ) (error{MalformedSequence} || DecodeLiteralsError)!void {
if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;
try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
@@ -288,16 +311,23 @@ pub const DecodeState = struct {
dest: *RingBuffer,
literals: LiteralsSection,
sequence: Sequence,
- ) !void {
+ ) (error{MalformedSequence} || DecodeLiteralsError)!void {
if (sequence.offset > dest.data.len) return error.MalformedSequence;
try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
- const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
+ const copy_start = dest.write_index + dest.data.len - sequence.offset;
+ const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
// TODO: would std.mem.copy and figuring out dest slice be better/faster?
for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
}
+ const DecodeSequenceError = error{
+ OffsetCodeTooLarge,
+ EndOfStream,
+ MalformedSequence,
+ MalformedFseBits,
+ } || DecodeLiteralsError;
/// Decode one sequence from `bit_reader` into `dest`, written starting at
/// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
/// `error.MalformedSequence` error if the decompressed sequence would be longer
@@ -311,10 +341,10 @@ pub const DecodeState = struct {
dest: []u8,
write_pos: usize,
literals: LiteralsSection,
- bit_reader: anytype,
+ bit_reader: *ReverseBitReader,
sequence_size_limit: usize,
last_sequence: bool,
- ) !usize {
+ ) DecodeSequenceError!usize {
const sequence = try self.nextSequence(bit_reader);
const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
if (sequence_length > sequence_size_limit) return error.MalformedSequence;
@@ -336,7 +366,7 @@ pub const DecodeState = struct {
bit_reader: anytype,
sequence_size_limit: usize,
last_sequence: bool,
- ) !usize {
+ ) DecodeSequenceError!usize {
const sequence = try self.nextSequence(bit_reader);
const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
if (sequence_length > sequence_size_limit) return error.MalformedSequence;
@@ -350,26 +380,63 @@ pub const DecodeState = struct {
return sequence_length;
}
- fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void {
+ fn nextLiteralMultiStream(
+ self: *DecodeState,
+ literals: LiteralsSection,
+ ) error{BitStreamHasNoStartBit}!void {
self.literal_stream_index += 1;
try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
}
- fn initLiteralStream(self: *DecodeState, bytes: []const u8) !void {
+ fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
try self.literal_stream_reader.init(bytes);
}
+ const LiteralBitsError = error{
+ BitStreamHasNoStartBit,
+ UnexpectedEndOfLiteralStream,
+ };
+ fn readLiteralsBits(
+ self: *DecodeState,
+ comptime T: type,
+ bit_count_to_read: usize,
+ literals: LiteralsSection,
+ ) LiteralBitsError!T {
+ return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
+ if (literals.streams == .four and self.literal_stream_index < 3) {
+ try self.nextLiteralMultiStream(literals);
+ break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
+ return error.UnexpectedEndOfLiteralStream;
+ } else {
+ return error.UnexpectedEndOfLiteralStream;
+ }
+ };
+ }
+
+ const DecodeLiteralsError = error{
+ MalformedLiteralsLength,
+ PrefixNotFound,
+ } || LiteralBitsError;
+
/// Decode `len` bytes of literals into `dest`. `literals` should be the
/// `LiteralsSection` that was passed to `prepare()`. Returns
/// `error.MalformedLiteralsLength` if the number of literal bytes decoded by
/// `self` plus `len` is greater than the regenerated size of `literals`.
/// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if
/// there are problems decoding Huffman compressed literals.
- pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void {
- if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+ pub fn decodeLiteralsSlice(
+ self: *DecodeState,
+ dest: []u8,
+ literals: LiteralsSection,
+ len: usize,
+ ) DecodeLiteralsError!void {
+ if (self.literal_written_count + len > literals.header.regenerated_size)
+ return error.MalformedLiteralsLength;
+
switch (literals.header.block_type) {
.raw => {
- const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+ const literals_end = self.literal_written_count + len;
+ const literal_data = literals.streams.one[self.literal_written_count..literals_end];
std.mem.copy(u8, dest, literal_data);
self.literal_written_count += len;
},
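
The `readLiteralsBits` helper above leans on a labeled-block `catch`; here is a standalone sketch of the same idiom with purely hypothetical types, for readers who have not seen `break :label` used from a `catch` before:

const std = @import("std");

// Hypothetical two-source reader: try the primary source, fall back once,
// and translate a second failure into a more specific error.
const Source = struct {
    items: []const u16,
    index: usize = 0,

    fn read(self: *Source) error{EndOfStream}!u16 {
        if (self.index >= self.items.len) return error.EndOfStream;
        defer self.index += 1;
        return self.items[self.index];
    }
};

fn readWithFallback(primary: *Source, fallback: *Source) error{UnexpectedEndOfLiteralStream}!u16 {
    return primary.read() catch bits: {
        break :bits fallback.read() catch return error.UnexpectedEndOfLiteralStream;
    };
}

test "labeled-block catch falls back to the second source" {
    var a = Source{ .items = &[_]u16{} };
    var b = Source{ .items = &[_]u16{42} };
    try std.testing.expectEqual(@as(u16, 42), try readWithFallback(&a, &b));
}
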
@@ -395,15 +462,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) {
var prefix: u16 = 0;
while (true) {
- const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
- switch (err) {
- error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
- try self.nextLiteralMultiStream(literals);
- break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
- } else {
- return error.UnexpectedEndOfLiteralStream;
- },
- };
+ const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
prefix <<= bit_count_to_read;
prefix |= new_bits;
bits_read += bit_count_to_read;
@@ -434,11 +493,19 @@ pub const DecodeState = struct {
}
/// Decode literals into `dest`; see `decodeLiteralsSlice()`.
- pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void {
- if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+ pub fn decodeLiteralsRingBuffer(
+ self: *DecodeState,
+ dest: *RingBuffer,
+ literals: LiteralsSection,
+ len: usize,
+ ) DecodeLiteralsError!void {
+ if (self.literal_written_count + len > literals.header.regenerated_size)
+ return error.MalformedLiteralsLength;
+
switch (literals.header.block_type) {
.raw => {
- const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+ const literals_end = self.literal_written_count + len;
+ const literal_data = literals.streams.one[self.literal_written_count..literals_end];
dest.writeSliceAssumeCapacity(literal_data);
self.literal_written_count += len;
},
@@ -464,15 +531,7 @@ pub const DecodeState = struct {
while (i < len) : (i += 1) {
var prefix: u16 = 0;
while (true) {
- const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
- switch (err) {
- error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
- try self.nextLiteralMultiStream(literals);
- break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
- } else {
- return error.UnexpectedEndOfLiteralStream;
- },
- };
+ const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
prefix <<= bit_count_to_read;
prefix |= new_bits;
bits_read += bit_count_to_read;
@@ -514,6 +573,11 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma
const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
+const FrameError = error{
+ DictionaryIdFlagUnsupported,
+ ChecksumFailure,
+} || InvalidBit || DecodeBlockError;
+
/// Decode a Zstandard frame from `src` into `dest`, returning the number of
/// bytes read from `src` and written to `dest`; if the frame does not declare
/// its decompressed content size `error.UnknownContentSizeUnsupported` is
@@ -521,7 +585,11 @@ const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max
/// dictionary, and `error.ChecksumFailure` if `verify_checksum` is `true` and
/// the frame contains a checksum that does not match the checksum computed from
/// the decompressed frame.
-pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
+pub fn decodeZStandardFrame(
+ dest: []u8,
+ src: []const u8,
+ verify_checksum: bool,
+) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount {
assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
var consumed_count: usize = 4;
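
A hedged sketch tying the two public entry points together: size the destination from the declared content size, then decode. The import path is hypothetical, and the error names surfaced to callers are taken from the signatures above:

const std = @import("std");
const decompress = @import("decompress.zig"); // hypothetical import path

fn decodeDeclaredSizeFrame(allocator: std.mem.Allocator, src: []const u8) ![]u8 {
    // Assumes `src` starts with a Zstandard (not skippable) frame.
    const content_size = (try decompress.getFrameDecompressedSize(src)) orelse
        return error.UnknownContentSizeUnsupported;
    const dest_len = std.math.cast(usize, content_size) orelse return error.ContentTooLarge;
    const dest = try allocator.alloc(u8, dest_len);
    errdefer allocator.free(dest);
    _ = try decompress.decodeZStandardFrame(dest, src, true);
    return dest;
}
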
@@ -530,13 +598,11 @@ pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool)
if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
const content_size = frame_header.content_size orelse return error.UnknownContentSizeUnsupported;
- // const window_size = frameWindowSize(header) orelse return error.WindowSizeUnknown;
if (dest.len < content_size) return error.ContentTooLarge;
const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
var hash_state = if (should_compute_checksum) std.hash.XxHash64.init(0) else undefined;
- // TODO: block_maximum_size should be @min(1 << 17, window_size);
const written_count = try decodeFrameBlocks(
dest,
src[consumed_count..],
@@ -567,7 +633,7 @@ pub fn decodeZStandardFrameAlloc(
src: []const u8,
verify_checksum: bool,
window_size_max: usize,
-) ![]u8 {
+) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory } || FrameError)![]u8 {
var result = std.ArrayList(u8).init(allocator);
assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
var consumed_count: usize = 4;
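
For frames that do not declare their content size, the allocating variant is the fallback; a short sketch, assuming the function also takes an allocator ahead of the parameters visible in this hunk and that `decompress.zig` is the import path:

const std = @import("std");
const decompress = @import("decompress.zig"); // hypothetical import path

fn decodeUnknownSizeFrame(allocator: std.mem.Allocator, src: []const u8) ![]u8 {
    // Cap the window at 8 MiB so a hostile frame cannot demand an outsized buffer.
    const window_size_max: usize = 8 * 1024 * 1024;
    return decompress.decodeZStandardFrameAlloc(allocator, src, true, window_size_max);
}
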
@@ -628,7 +694,7 @@ pub fn decodeZStandardFrameAlloc(
block_header = decodeBlockHeader(src[consumed_count..][0..3]);
consumed_count += 3;
}) {
- if (block_header.block_size > block_size_maximum) return error.CompressedBlockSizeOverMaximum;
+ if (block_header.block_size > block_size_maximum) return error.BlockSizeOverMaximum;
const written_size = try decodeBlockRingBuffer(
&ring_buffer,
src[consumed_count..],
@@ -637,7 +703,7 @@ pub fn decodeZStandardFrameAlloc(
&consumed_count,
block_size_maximum,
);
- if (written_size > block_size_maximum) return error.DecompressedBlockSizeOverMaximum;
+ if (written_size > block_size_maximum) return error.BlockSizeOverMaximum;
const written_slice = ring_buffer.sliceLast(written_size);
try result.appendSlice(written_slice.first);
try result.appendSlice(written_slice.second);
@@ -650,8 +716,21 @@ pub fn decodeZStandardFrameAlloc(
return result.toOwnedSlice();
}
+const DecodeBlockError = error{
+ BlockSizeOverMaximum,
+ MalformedBlockSize,
+ ReservedBlock,
+ MalformedRleBlock,
+ MalformedCompressedBlock,
+};
+
/// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`.
-pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64) !usize {
+pub fn decodeFrameBlocks(
+ dest: []u8,
+ src: []const u8,
+ consumed_count: *usize,
+ hash: ?*std.hash.XxHash64,
+) DecodeBlockError!usize {
// These tables take 7680 bytes
var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
var match_fse_data: [match_table_size_max]Table.Fse = undefined;
@@ -702,7 +781,12 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
return written_count;
}
-fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRawBlock(
+ dest: []u8,
+ src: []const u8,
+ block_size: u21,
+ consumed_count: *usize,
+) error{MalformedBlockSize}!usize {
if (src.len < block_size) return error.MalformedBlockSize;
const data = src[0..block_size];
std.mem.copy(u8, dest, data);
@@ -710,7 +794,12 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
return block_size;
}
-fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRawBlockRingBuffer(
+ dest: *RingBuffer,
+ src: []const u8,
+ block_size: u21,
+ consumed_count: *usize,
+) error{MalformedBlockSize}!usize {
if (src.len < block_size) return error.MalformedBlockSize;
const data = src[0..block_size];
dest.writeSliceAssumeCapacity(data);
@@ -718,7 +807,12 @@ fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21,
return block_size;
}
-fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRleBlock(
+ dest: []u8,
+ src: []const u8,
+ block_size: u21,
+ consumed_count: *usize,
+) error{MalformedRleBlock}!usize {
if (src.len < 1) return error.MalformedRleBlock;
var write_pos: usize = 0;
while (write_pos < block_size) : (write_pos += 1) {
@@ -728,7 +822,12 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
return block_size;
}
-fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) !usize {
+fn decodeRleBlockRingBuffer(
+ dest: *RingBuffer,
+ src: []const u8,
+ block_size: u21,
+ consumed_count: *usize,
+) error{MalformedRleBlock}!usize {
if (src.len < 1) return error.MalformedRleBlock;
var write_pos: usize = 0;
while (write_pos < block_size) : (write_pos += 1) {
@@ -749,7 +848,7 @@ pub fn decodeBlock(
decode_state: *DecodeState,
consumed_count: *usize,
written_count: usize,
-) !usize {
+) DecodeBlockError!usize {
const block_size_max = @min(1 << 17, dest[written_count..].len); // 128KiB
const block_size = block_header.block_size;
if (block_size_max < block_size) return error.BlockSizeOverMaximum;
@@ -759,31 +858,33 @@ pub fn decodeBlock(
.compressed => {
if (src.len < block_size) return error.MalformedBlockSize;
var bytes_read: usize = 0;
- const literals = try decodeLiteralsSection(src, &bytes_read);
- const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+ const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+ const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+ return error.MalformedCompressedBlock;
- bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
+ bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+ return error.MalformedCompressedBlock;
var bytes_written: usize = 0;
if (sequences_header.sequence_count > 0) {
const bit_stream_bytes = src[bytes_read..block_size];
var bit_stream: ReverseBitReader = undefined;
- try bit_stream.init(bit_stream_bytes);
+ bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
- try decode_state.readInitialFseState(&bit_stream);
+ decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
var sequence_size_limit = block_size_max;
var i: usize = 0;
while (i < sequences_header.sequence_count) : (i += 1) {
const write_pos = written_count + bytes_written;
- const decompressed_size = try decode_state.decodeSequenceSlice(
+ const decompressed_size = decode_state.decodeSequenceSlice(
dest,
write_pos,
literals,
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
- );
+ ) catch return error.MalformedCompressedBlock;
bytes_written += decompressed_size;
sequence_size_limit -= decompressed_size;
}
@@ -793,7 +894,8 @@ pub fn decodeBlock(
if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count;
- try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
+ decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len) catch
+ return error.MalformedCompressedBlock;
bytes_written += len;
}
@@ -802,7 +904,7 @@ pub fn decodeBlock(
consumed_count.* += bytes_read;
return bytes_written;
},
- .reserved => return error.FrameContainsReservedBlock,
+ .reserved => return error.ReservedBlock,
}
}
@@ -816,7 +918,7 @@ pub fn decodeBlockRingBuffer(
decode_state: *DecodeState,
consumed_count: *usize,
block_size_max: usize,
-) !usize {
+) DecodeBlockError!usize {
const block_size = block_header.block_size;
if (block_size_max < block_size) return error.BlockSizeOverMaximum;
switch (block_header.block_type) {
@@ -825,29 +927,31 @@ pub fn decodeBlockRingBuffer(
.compressed => {
if (src.len < block_size) return error.MalformedBlockSize;
var bytes_read: usize = 0;
- const literals = try decodeLiteralsSection(src, &bytes_read);
- const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+ const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+ const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+ return error.MalformedCompressedBlock;
- bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
+ bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+ return error.MalformedCompressedBlock;
var bytes_written: usize = 0;
if (sequences_header.sequence_count > 0) {
const bit_stream_bytes = src[bytes_read..block_size];
var bit_stream: ReverseBitReader = undefined;
- try bit_stream.init(bit_stream_bytes);
+ bit_stream.init(bit_stream_bytes) catch return error.MalformedCompressedBlock;
- try decode_state.readInitialFseState(&bit_stream);
+ decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
var sequence_size_limit = block_size_max;
var i: usize = 0;
while (i < sequences_header.sequence_count) : (i += 1) {
- const decompressed_size = try decode_state.decodeSequenceRingBuffer(
+ const decompressed_size = decode_state.decodeSequenceRingBuffer(
dest,
literals,
&bit_stream,
sequence_size_limit,
i == sequences_header.sequence_count - 1,
- );
+ ) catch return error.MalformedCompressedBlock;
bytes_written += decompressed_size;
sequence_size_limit -= decompressed_size;
}
@@ -857,7 +961,8 @@ pub fn decodeBlockRingBuffer(
if (decode_state.literal_written_count < literals.header.regenerated_size) {
const len = literals.header.regenerated_size - decode_state.literal_written_count;
- try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
+ decode_state.decodeLiteralsRingBuffer(dest, literals, len) catch
+ return error.MalformedCompressedBlock;
bytes_written += len;
}
@@ -866,7 +971,7 @@ pub fn decodeBlockRingBuffer(
consumed_count.* += bytes_read;
return bytes_written;
},
- .reserved => return error.FrameContainsReservedBlock,
+ .reserved => return error.ReservedBlock,
}
}
@@ -901,9 +1006,10 @@ pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
} else return header.content_size;
}
+const InvalidBit = error{ UnusedBitSet, ReservedBitSet };
/// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or
/// `error.ReservedBitSet` if the corresponding bits are set.
-pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZStandard.Header {
+pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) InvalidBit!frame.ZStandard.Header {
const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]);
if (descriptor.unused) return error.UnusedBitSet;
@@ -958,7 +1064,10 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the
/// number of bytes the section uses.
-pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
+pub fn decodeLiteralsSection(
+ src: []const u8,
+ consumed_count: *usize,
+) (error{ MalformedLiteralsHeader, MalformedLiteralsSection } || DecodeHuffmanError)!LiteralsSection {
var bytes_read: usize = 0;
const header = try decodeLiteralsHeader(src, &bytes_read);
switch (header.block_type) {
@@ -1032,7 +1141,13 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsS
}
}
-fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree {
+const DecodeHuffmanError = error{
+ MalformedHuffmanTree,
+ MalformedFseTable,
+ MalformedAccuracyLog,
+};
+
+fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError!LiteralsSection.HuffmanTree {
var bytes_read: usize = 0;
bytes_read += 1;
if (src.len == 0) return error.MalformedHuffmanTree;
@@ -1049,22 +1164,25 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.H
var bit_reader = bitReader(counting_reader.reader());
var entries: [1 << 6]Table.Fse = undefined;
- const table_size = try decodeFseTable(&bit_reader, 256, 6, &entries);
+ const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
+ error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
+ error.EndOfStream => return error.MalformedFseTable,
+ };
const accuracy_log = std.math.log2_int_ceil(usize, table_size);
const start_index = std.math.cast(usize, 1 + counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
var huff_data = src[start_index .. compressed_size + 1];
var huff_bits: ReverseBitReader = undefined;
- try huff_bits.init(huff_data);
+ huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;
var i: usize = 0;
- var even_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
- var odd_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
+ var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
+ var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
while (i < 255) {
const even_data = entries[even_state];
var read_bits: usize = 0;
- const even_bits = try huff_bits.readBits(u32, even_data.bits, &read_bits);
+ const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
i += 1;
if (read_bits < even_data.bits) {
@@ -1076,7 +1194,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.H
read_bits = 0;
const odd_data = entries[odd_state];
- const odd_bits = try huff_bits.readBits(u32, odd_data.bits, &read_bits);
+ const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
i += 1;
if (read_bits < odd_data.bits) {
@@ -1177,8 +1295,8 @@ fn lessThanByWeight(
}
/// Decode a literals section header.
-pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSection.Header {
- if (src.len == 0) return error.MalformedLiteralsSection;
+pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{MalformedLiteralsHeader}!LiteralsSection.Header {
+ if (src.len == 0) return error.MalformedLiteralsHeader;
const byte0 = src[0];
const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
@@ -1243,8 +1361,11 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSe
}
/// Decode a sequences section header.
-pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
- if (src.len == 0) return error.MalformedSequencesSection;
+pub fn decodeSequencesHeader(
+ src: []const u8,
+ consumed_count: *usize,
+) error{ MalformedSequencesHeader, ReservedBitSet }!SequencesSection.Header {
+ if (src.len == 0) return error.MalformedSequencesHeader;
var sequence_count: u24 = undefined;
var bytes_read: usize = 0;
@@ -1262,16 +1383,16 @@ pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences
sequence_count = byte0;
bytes_read += 1;
} else if (byte0 < 255) {
- if (src.len < 2) return error.MalformedSequencesSection;
+ if (src.len < 2) return error.MalformedSequencesHeader;
sequence_count = (@as(u24, (byte0 - 128)) << 8) + src[1];
bytes_read += 2;
} else {
- if (src.len < 3) return error.MalformedSequencesSection;
+ if (src.len < 3) return error.MalformedSequencesHeader;
sequence_count = src[1] + (@as(u24, src[2]) << 8) + 0x7F00;
bytes_read += 3;
}
- if (src.len < bytes_read + 1) return error.MalformedSequencesSection;
+ if (src.len < bytes_read + 1) return error.MalformedSequencesHeader;
const compression_modes = src[bytes_read];
bytes_read += 1;
@@ -1441,17 +1562,17 @@ pub const ReverseBitReader = struct {
byte_reader: ReversedByteReader,
bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),
- pub fn init(self: *ReverseBitReader, bytes: []const u8) !void {
+ pub fn init(self: *ReverseBitReader, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
self.byte_reader = ReversedByteReader.init(bytes);
self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
}
- pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
+ pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) error{EndOfStream}!U {
return self.bit_reader.readBitsNoEof(U, num_bits);
}
- pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
+ pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) error{}!U {
return try self.bit_reader.readBits(U, num_bits, out_bits);
}
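
A sketch of a test pinning down the `ReverseBitReader` contract made explicit above, written as if it lived in decompress.zig so the type and `std` are already in scope; the byte values are arbitrary:

test "ReverseBitReader error contract (sketch)" {
    var reader: ReverseBitReader = undefined;

    // An all-zero stream has no start bit, so `init` reports exactly that.
    try std.testing.expectError(error.BitStreamHasNoStartBit, reader.init(&[_]u8{ 0, 0 }));

    // With a start bit present, `init` succeeds and later reads can only
    // fail with `error.EndOfStream`.
    try reader.init(&[_]u8{ 0b1010_0000, 0xff });
    _ = try reader.readBitsNoEof(u8, 8);
}
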
lib/std/compress/zstandard/types.zig
@@ -92,7 +92,7 @@ pub const compressed_block = struct {
index: usize,
};
- pub fn query(self: HuffmanTree, index: usize, prefix: u16) !Result {
+ pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{PrefixNotFound}!Result {
var node = self.nodes[index];
const weight = node.weight;
var i: usize = index;