Commit 7558bf6451

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-24 15:30:17
std.compress.zstandard: minor cleanup and add doc comments
1 parent ab18adf
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -18,6 +18,10 @@ fn isSkippableMagic(magic: u32) bool {
     return frame.Skippable.magic_number_min <= magic and magic <= frame.Skippable.magic_number_max;
 }
 
+/// Returns the decompressed size of the frame at the start of `src`. Returns 0
+/// if the frame is skippable, `null` for Zstandard frames that do not
+/// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet`
+/// errors if the respective bits of the frame descriptor are set.
 pub fn getFrameDecompressedSize(src: []const u8) !?usize {
     switch (try frameType(src)) {
         .zstandard => {
@@ -28,7 +32,10 @@ pub fn getFrameDecompressedSize(src: []const u8) !?usize {
     }
 }
 
-pub fn frameType(src: []const u8) !frame.Kind {
+/// Returns the kind of frame at the beginning of `src`. Returns `BadMagic` if
+/// `src` begins with bytes not equal to the Zstandard frame magic number, or
+/// outside the range of magic numbers for skippable frames.
+pub fn frameType(src: []const u8) error{BadMagic}!frame.Kind {
     const magic = readInt(u32, src[0..4]);
     return if (magic == frame.ZStandard.magic_number)
         .zstandard
@@ -43,11 +50,13 @@ const ReadWriteCount = struct {
     write_count: usize,
 };
 
+/// Decodes the frame at the start of `src` into `dest`. Returns the number of
+/// bytes read from `src` and written to `dest`.
 pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
     return switch (try frameType(src)) {
         .zstandard => decodeZStandardFrame(dest, src, verify_checksum),
         .skippable => ReadWriteCount{
-            .read_count = try skippableFrameSize(src[0..8]) + 8,
+            .read_count = skippableFrameSize(src[0..8]) + 8,
             .write_count = 0,
         },
     };
@@ -82,6 +91,10 @@ pub const DecodeState = struct {
         };
     }
 
+    /// Prepare the decoder to decode a compressed block. Loads the literals
+    /// stream and Huffman tree from `literals` and reads the FSE tables from `src`.
+    /// Returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's
+    /// first byte does not have any bits set.
     pub fn prepare(
         self: *DecodeState,
         src: []const u8,
@@ -130,6 +143,8 @@ pub const DecodeState = struct {
         return 0;
     }
 
+    /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream`
+    /// if `bit_reader` does not contain enough bits.
     pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void {
         self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
         self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
@@ -283,6 +298,14 @@ pub const DecodeState = struct {
         for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
     }
 
+    /// Decode one sequence from `bit_reader` into `dest`, written starting at
+    /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
+    /// `error.MalformedSequence` error if the decompressed sequence would be longer
+    /// than `sequence_size_limit` or the sequence's offset is too large; returns
+    /// `error.EndOfStream` if `bit_reader` does not contain enough bits; returns
+    /// `error.UnexpectedEndOfLiteralStream` if the decoder state's literal streams
+    /// do not contain enough literals for the sequence (this may mean the literal
+    /// stream or the sequence is malformed).
     pub fn decodeSequenceSlice(
         self: *DecodeState,
         dest: []u8,
@@ -305,6 +328,7 @@ pub const DecodeState = struct {
         return sequence_length;
     }
 
+    /// Decode one sequence from `bit_reader` into `dest`; see `decodeSequenceSlice`.
     pub fn decodeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
@@ -335,6 +359,12 @@ pub const DecodeState = struct {
         try self.literal_stream_reader.init(bytes);
     }
 
+    /// Decode `len` bytes of literals into `dest`. `literals` should be the
+    /// `LiteralsSection` that was passed to `prepare()`. Returns
+    /// `error.MalformedLiteralsLength` if the number of literal bytes decoded by
+    /// `self` plus `len` is greater than the regenerated size of `literals`.
+    /// Returns `error.UnexpectedEndOfLiteralStream` and `error.PrefixNotFound` if
+    /// there are problems decoding Huffman compressed literals.
     pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
@@ -403,6 +433,7 @@ pub const DecodeState = struct {
         }
     }
 
+    /// Decode literals into `dest`; see `decodeLiteralsSlice()`.
     pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
@@ -483,6 +514,13 @@ const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_ma
 const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
 
+/// Decode a Zstandard frame from `src` into `dest`, returning the number of
+/// bytes read from `src` and written to `dest`; if the frame does not declare
+/// its decompressed content size `error.UnknownContentSizeUnsupported` is
+/// returned. Returns `error.DictionaryIdFlagUnsupported` if the frame uses a
+/// dictionary, and `error.ChecksumFailure` if `verify_checksum` is `true` and
+/// the frame contains a checksum that does not match the checksum computed from
+/// the decompressed frame.
 pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWriteCount {
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
@@ -520,6 +558,10 @@ pub fn decodeZStandardFrame(dest: []u8, src: []const u8, verify_checksum: bool)
     return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count };
 }
 
+/// Decode a Zstandard frame from `src` and return the decompressed bytes; see
+/// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame
+/// does not declare its content size or a window descriptor (this indicates a
+/// malformed frame).
 pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8, verify_checksum: bool) ![]u8 {
     var result = std.ArrayList(u8).init(allocator);
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
@@ -599,6 +641,7 @@ pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8,
     return result.toOwnedSlice();
 }
 
+/// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`.
 pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, hash: ?*std.hash.XxHash64) !usize {
     // These tables take 7680 bytes
     var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
@@ -686,6 +729,10 @@ fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21,
     return block_size;
 }
 
+/// Decode a single block from `src` into `dest`. The beginning of `src` should
+/// be the start of the block content (i.e. directly after the block header).
+/// Increments `consumed_count` by the number of bytes read from `src` to decode
+/// the block and returns the decompressed size of the block.
 pub fn decodeBlock(
     dest: []u8,
     src: []const u8,
@@ -750,6 +797,9 @@ pub fn decodeBlock(
     }
 }
 
+/// Decode a single block from `src` into `dest`; see `decodeBlock()`. Returns
+/// the size of the decompressed block, which can be used with `dest.sliceLast()`
+/// to get the decompressed bytes.
 pub fn decodeBlockRingBuffer(
     dest: *RingBuffer,
     src: []const u8,
@@ -811,6 +861,7 @@ pub fn decodeBlockRingBuffer(
     }
 }
 
+/// Decode the header of a skippable frame.
 pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     const magic = readInt(u32, src[0..4]);
     assert(isSkippableMagic(magic));
@@ -821,12 +872,15 @@ pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     };
 }
 
-pub fn skippableFrameSize(src: *const [8]u8) !usize {
+/// Returns the content size of a skippable frame.
+pub fn skippableFrameSize(src: *const [8]u8) usize {
     assert(isSkippableMagic(readInt(u32, src[0..4])));
     const frame_size = readInt(u32, src[4..8]);
     return frame_size;
 }
 
+/// Returns the window size required to decompress a frame, or `null` if it cannot be
+/// determined, which indicates a malformed frame header.
 pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
     if (header.window_descriptor) |descriptor| {
         const exponent = (descriptor & 0b11111000) >> 3;
@@ -838,6 +892,8 @@ pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
     } else return header.content_size;
 }
 
+/// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or
+/// `error.ReservedBitSet` if the corresponding bits are set.
 pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZStandard.Header {
     const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]);
 
@@ -879,6 +935,7 @@ pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) !frame.ZS
     return header;
 }
 
+/// Decode the header of a block.
 pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
     const last_block = src[0] & 1 == 1;
     const block_type = @intToEnum(frame.ZStandard.Block.Type, (src[0] & 0b110) >> 1);
@@ -890,6 +947,8 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
     };
 }
 
+/// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the
+/// number of bytes the section uses.
 pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
     var bytes_read: usize = 0;
     const header = try decodeLiteralsHeader(src, &bytes_read);
@@ -1107,6 +1166,7 @@ fn lessThanByWeight(
     return weights[lhs.symbol] < weights[rhs.symbol];
 }
 
+/// Decode a literals section header.
 pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSection.Header {
     if (src.len == 0) return error.MalformedLiteralsSection;
     const byte0 = src[0];
@@ -1172,6 +1232,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) !LiteralsSe
     };
 }
 
+/// Decode a sequences section header.
 pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
     if (src.len == 0) return error.MalformedSequencesSection;
     var sequence_count: u24 = undefined;
@@ -1241,7 +1302,8 @@ fn buildFseTable(values: []const u16, entries: []Table.Fse) !void {
         if (value == 0 or value == 1) continue;
         const probability = value - 1;
 
-        const state_share_dividend = try std.math.ceilPowerOfTwo(u16, probability);
+        const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch
+            return error.MalformedFseTable;
         const share_size = @divExact(total_probability, state_share_dividend);
         const double_state_count = state_share_dividend - probability;
         const single_state_count = probability - double_state_count;
@@ -1363,6 +1425,8 @@ const ReversedByteReader = struct {
     }
 };
 
+/// A bit reader for reading the reversed bit streams used to encode
+/// FSE compressed data.
 pub const ReverseBitReader = struct {
     byte_reader: ReversedByteReader,
     bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),