Commit 5723291444

dweiller <4678790+dweiller@users.noreply.github.com>
2023-02-02 06:19:13
std.compress.zstandard: add `decodeBlockReader`
1 parent 947ad3e
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -14,29 +14,18 @@ fn readVarInt(comptime T: type, bytes: []const u8) T {
     return std.mem.readVarInt(T, bytes, .Little);
 }
 
-fn isSkippableMagic(magic: u32) bool {
+pub fn isSkippableMagic(magic: u32) bool {
     return frame.Skippable.magic_number_min <= magic and magic <= frame.Skippable.magic_number_max;
 }
 
-/// Returns the decompressed size of the frame at the start of `src`. Returns 0
-/// if the the frame is skippable, `null` for Zstanndard frames that do not
-/// declare their content size. Returns `UnusedBitSet` and `ReservedBitSet`
-/// errors if the respective bits of the the frame descriptor are set.
-pub fn getFrameDecompressedSize(src: []const u8) (InvalidBit || error{BadMagic})!?u64 {
-    switch (try frameType(src)) {
-        .zstandard => {
-            const header = try decodeZStandardHeader(src[4..], null);
-            return header.content_size;
-        },
-        .skippable => return 0,
-    }
-}
-
-/// Returns the kind of frame at the beginning of `src`. Returns `BadMagic` if
-/// `src` begin with bytes not equal to the Zstandard frame magic number, or
-/// outside the range of magic numbers for skippable frames.
-pub fn frameType(src: []const u8) error{BadMagic}!frame.Kind {
-    const magic = readInt(u32, src[0..4]);
+/// Returns the kind of frame at the beginning of `src`.
+///
+/// Errors:
+///   - returns `error.BadMagic` if `source` begins with bytes not equal to the
+///     Zstandard frame magic number, or outside the range of magic numbers for
+///     skippable frames.
+pub fn decodeFrameType(source: anytype) !frame.Kind {
+    const magic = try source.readIntLittle(u32);
     return if (magic == frame.ZStandard.magic_number)
         .zstandard
     else if (isSkippableMagic(magic))
@@ -52,15 +41,21 @@ const ReadWriteCount = struct {
 
 /// Decodes the frame at the start of `src` into `dest`. Returns the number of
 /// bytes read from `src` and written to `dest`.
+///
+/// Errors:
+///   - returns `error.UnknownContentSizeUnsupported`
+///   - returns `error.ContentTooLarge`
+///   - returns `error.BadMagic`
 pub fn decodeFrame(
     dest: []u8,
     src: []const u8,
     verify_checksum: bool,
-) (error{ UnknownContentSizeUnsupported, ContentTooLarge, BadMagic } || FrameError)!ReadWriteCount {
-    return switch (try frameType(src)) {
+) !ReadWriteCount {
+    var fbs = std.io.fixedBufferStream(src);
+    return switch (try decodeFrameType(fbs.reader())) {
         .zstandard => decodeZStandardFrame(dest, src, verify_checksum),
         .skippable => ReadWriteCount{
-            .read_count = skippableFrameSize(src[0..8]) + 8,
+            .read_count = try fbs.reader().readIntLittle(u32) + 8,
             .write_count = 0,
         },
     };
@@ -97,16 +92,52 @@ pub const DecodeState = struct {
         };
     }
 
+    pub fn init(
+        literal_fse_buffer: []Table.Fse,
+        match_fse_buffer: []Table.Fse,
+        offset_fse_buffer: []Table.Fse,
+    ) DecodeState {
+        return DecodeState{
+            .repeat_offsets = .{
+                types.compressed_block.start_repeated_offset_1,
+                types.compressed_block.start_repeated_offset_2,
+                types.compressed_block.start_repeated_offset_3,
+            },
+
+            .offset = undefined,
+            .match = undefined,
+            .literal = undefined,
+
+            .literal_fse_buffer = literal_fse_buffer,
+            .match_fse_buffer = match_fse_buffer,
+            .offset_fse_buffer = offset_fse_buffer,
+
+            .fse_tables_undefined = true,
+
+            .literal_written_count = 0,
+            .literal_header = undefined,
+            .literal_streams = undefined,
+            .literal_stream_reader = undefined,
+            .literal_stream_index = undefined,
+            .huffman_tree = null,
+        };
+    }
+
     /// Prepare the decoder to decode a compressed block. Loads the literals
-    /// stream and Huffman tree from `literals` and reads the FSE tables from `src`.
-    /// Returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's
-    /// first byte does not have any bits set.
+    /// stream and Huffman tree from `literals` and reads the FSE tables from
+    /// `source`.
+    ///
+    /// Errors:
+    ///   - returns `error.BitStreamHasNoStartBit` if the (reversed) literal bitstream's
+    ///     first byte does not have any bits set.
+    ///   - returns `error.TreelessLiteralsFirst` if `literals` is a treeless literals section
+    ///     and the decode state does not have a Huffman tree from a previous block.
     pub fn prepare(
         self: *DecodeState,
-        src: []const u8,
+        source: anytype,
         literals: LiteralsSection,
         sequences_header: SequencesSection.Header,
-    ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize {
+    ) !void {
         self.literal_written_count = 0;
         self.literal_header = literals.header;
         self.literal_streams = literals.streams;
@@ -129,28 +160,11 @@ pub const DecodeState = struct {
         }
 
         if (sequences_header.sequence_count > 0) {
-            var bytes_read = try self.updateFseTable(
-                src,
-                .literal,
-                sequences_header.literal_lengths,
-            );
-
-            bytes_read += try self.updateFseTable(
-                src[bytes_read..],
-                .offset,
-                sequences_header.offsets,
-            );
-
-            bytes_read += try self.updateFseTable(
-                src[bytes_read..],
-                .match,
-                sequences_header.match_lengths,
-            );
+            try self.updateFseTable(source, .literal, sequences_header.literal_lengths);
+            try self.updateFseTable(source, .offset, sequences_header.offsets);
+            try self.updateFseTable(source, .match, sequences_header.match_lengths);
             self.fse_tables_undefined = false;
-
-            return bytes_read;
         }
-        return 0;
     }
 
     /// Read initial FSE states for sequence decoding. Returns `error.EndOfStream`
@@ -208,10 +222,10 @@ pub const DecodeState = struct {
 
     fn updateFseTable(
         self: *DecodeState,
-        src: []const u8,
+        source: anytype,
         comptime choice: DataType,
         mode: SequencesSection.Header.Mode,
-    ) FseTableError!usize {
+    ) !void {
         const field_name = @tagName(choice);
         switch (mode) {
             .predefined => {
@@ -220,17 +234,13 @@ pub const DecodeState = struct {
 
                 @field(self, field_name).table =
                     @field(types.compressed_block, "predefined_" ++ field_name ++ "_fse_table");
-                return 0;
             },
             .rle => {
                 @field(self, field_name).accuracy_log = 0;
-                @field(self, field_name).table = .{ .rle = src[0] };
-                return 1;
+                @field(self, field_name).table = .{ .rle = try source.readByte() };
             },
             .fse => {
-                var stream = std.io.fixedBufferStream(src);
-                var counting_reader = std.io.countingReader(stream.reader());
-                var bit_reader = bitReader(counting_reader.reader());
+                var bit_reader = bitReader(source);
 
                 const table_size = try decodeFseTable(
                     &bit_reader,
@@ -242,9 +252,8 @@ pub const DecodeState = struct {
                     .fse = @field(self, field_name ++ "_fse_buffer")[0..table_size],
                 };
                 @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
-                return std.math.cast(usize, counting_reader.bytes_read) orelse error.MalformedFseTable;
             },
-            .repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0,
+            .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst,
         }
     }
 
@@ -462,11 +471,15 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
+                        const new_bits = self.readLiteralsBits(u16, bit_count_to_read) catch |err| {
+                            return err;
+                        };
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
-                        const result = try huffman_tree.query(huffman_tree_index, prefix);
+                        const result = huffman_tree.query(huffman_tree_index, prefix) catch |err| {
+                            return err;
+                        };
 
                         switch (result) {
                             .symbol => |sym| {
@@ -589,11 +602,14 @@ pub fn decodeZStandardFrame(
     dest: []u8,
     src: []const u8,
     verify_checksum: bool,
-) (error{ UnknownContentSizeUnsupported, ContentTooLarge } || FrameError)!ReadWriteCount {
+) (error{ UnknownContentSizeUnsupported, ContentTooLarge, EndOfStream } || FrameError)!ReadWriteCount {
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
 
-    const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
+    var fbs = std.io.fixedBufferStream(src[consumed_count..]);
+    var source = fbs.reader();
+    const frame_header = try decodeZStandardHeader(source);
+    consumed_count += fbs.pos;
 
     if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
 
@@ -649,18 +665,25 @@ pub const FrameContext = struct {
 /// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame
 /// does not declare its content size or a window descriptor (this indicates a
 /// malformed frame).
+///
+/// Errors:
+///   - returns `error.WindowTooLarge`
+///   - returns `error.WindowSizeUnknown`
 pub fn decodeZStandardFrameAlloc(
     allocator: std.mem.Allocator,
     src: []const u8,
     verify_checksum: bool,
     window_size_max: usize,
-) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory } || FrameError)![]u8 {
+) (error{ WindowSizeUnknown, WindowTooLarge, OutOfMemory, EndOfStream } || FrameError)![]u8 {
     var result = std.ArrayList(u8).init(allocator);
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
 
     var frame_context = context: {
-        const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
+        var fbs = std.io.fixedBufferStream(src[consumed_count..]);
+        var source = fbs.reader();
+        const frame_header = try decodeZStandardHeader(source);
+        consumed_count += fbs.pos;
         break :context try FrameContext.init(frame_header, window_size_max, verify_checksum);
     };
 
@@ -674,30 +697,7 @@ pub fn decodeZStandardFrameAlloc(
 
     var block_header = decodeBlockHeader(src[consumed_count..][0..3]);
     consumed_count += 3;
-    var decode_state = DecodeState{
-        .repeat_offsets = .{
-            types.compressed_block.start_repeated_offset_1,
-            types.compressed_block.start_repeated_offset_2,
-            types.compressed_block.start_repeated_offset_3,
-        },
-
-        .offset = undefined,
-        .match = undefined,
-        .literal = undefined,
-
-        .literal_fse_buffer = &literal_fse_data,
-        .match_fse_buffer = &match_fse_data,
-        .offset_fse_buffer = &offset_fse_data,
-
-        .fse_tables_undefined = true,
-
-        .literal_written_count = 0,
-        .literal_header = undefined,
-        .literal_streams = undefined,
-        .literal_stream_reader = undefined,
-        .literal_stream_index = undefined,
-        .huffman_tree = null,
-    };
+    var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data);
     while (true) : ({
         block_header = decodeBlockHeader(src[consumed_count..][0..3]);
         consumed_count += 3;
@@ -754,30 +754,7 @@ pub fn decodeFrameBlocks(
     var block_header = decodeBlockHeader(src[0..3]);
     var bytes_read: usize = 3;
     defer consumed_count.* += bytes_read;
-    var decode_state = DecodeState{
-        .repeat_offsets = .{
-            types.compressed_block.start_repeated_offset_1,
-            types.compressed_block.start_repeated_offset_2,
-            types.compressed_block.start_repeated_offset_3,
-        },
-
-        .offset = undefined,
-        .match = undefined,
-        .literal = undefined,
-
-        .literal_fse_buffer = &literal_fse_data,
-        .match_fse_buffer = &match_fse_data,
-        .offset_fse_buffer = &offset_fse_data,
-
-        .fse_tables_undefined = true,
-
-        .literal_written_count = 0,
-        .literal_header = undefined,
-        .literal_streams = undefined,
-        .literal_stream_reader = undefined,
-        .literal_stream_index = undefined,
-        .huffman_tree = null,
-    };
+    var decode_state = DecodeState.init(&literal_fse_data, &match_fse_data, &offset_fse_data);
     var written_count: usize = 0;
     while (true) : ({
         block_header = decodeBlockHeader(src[bytes_read..][0..3]);
@@ -798,62 +775,6 @@ pub fn decodeFrameBlocks(
     return written_count;
 }
 
-fn decodeRawBlock(
-    dest: []u8,
-    src: []const u8,
-    block_size: u21,
-    consumed_count: *usize,
-) error{MalformedBlockSize}!usize {
-    if (src.len < block_size) return error.MalformedBlockSize;
-    const data = src[0..block_size];
-    std.mem.copy(u8, dest, data);
-    consumed_count.* += block_size;
-    return block_size;
-}
-
-fn decodeRawBlockRingBuffer(
-    dest: *RingBuffer,
-    src: []const u8,
-    block_size: u21,
-    consumed_count: *usize,
-) error{MalformedBlockSize}!usize {
-    if (src.len < block_size) return error.MalformedBlockSize;
-    const data = src[0..block_size];
-    dest.writeSliceAssumeCapacity(data);
-    consumed_count.* += block_size;
-    return block_size;
-}
-
-fn decodeRleBlock(
-    dest: []u8,
-    src: []const u8,
-    block_size: u21,
-    consumed_count: *usize,
-) error{MalformedRleBlock}!usize {
-    if (src.len < 1) return error.MalformedRleBlock;
-    var write_pos: usize = 0;
-    while (write_pos < block_size) : (write_pos += 1) {
-        dest[write_pos] = src[0];
-    }
-    consumed_count.* += 1;
-    return block_size;
-}
-
-fn decodeRleBlockRingBuffer(
-    dest: *RingBuffer,
-    src: []const u8,
-    block_size: u21,
-    consumed_count: *usize,
-) error{MalformedRleBlock}!usize {
-    if (src.len < 1) return error.MalformedRleBlock;
-    var write_pos: usize = 0;
-    while (write_pos < block_size) : (write_pos += 1) {
-        dest.writeAssumeCapacity(src[0]);
-    }
-    consumed_count.* += 1;
-    return block_size;
-}
-
 /// Decode a single block from `src` into `dest`. The beginning of `src` should
 /// be the start of the block content (i.e. directly after the block header).
 /// Increments `consumed_count` by the number of bytes read from `src` to decode
@@ -870,19 +791,37 @@ pub fn decodeBlock(
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
-        .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count),
-        .rle => return decodeRleBlock(dest[written_count..], src, block_size, consumed_count),
+        .raw => {
+            if (src.len < block_size) return error.MalformedBlockSize;
+            const data = src[0..block_size];
+            std.mem.copy(u8, dest[written_count..], data);
+            consumed_count.* += block_size;
+            return block_size;
+        },
+        .rle => {
+            if (src.len < 1) return error.MalformedRleBlock;
+            var write_pos: usize = written_count;
+            while (write_pos < block_size + written_count) : (write_pos += 1) {
+                dest[write_pos] = src[0];
+            }
+            consumed_count.* += 1;
+            return block_size;
+        },
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch
+            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
                 return error.MalformedCompressedBlock;
-            const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
+            const fbs_reader = fbs.reader();
+            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                 return error.MalformedCompressedBlock;
 
-            bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                 return error.MalformedCompressedBlock;
 
+            bytes_read += fbs.pos;
+
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
@@ -938,19 +877,37 @@ pub fn decodeBlockRingBuffer(
     const block_size = block_header.block_size;
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
-        .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
-        .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
+        .raw => {
+            if (src.len < block_size) return error.MalformedBlockSize;
+            const data = src[0..block_size];
+            dest.writeSliceAssumeCapacity(data);
+            consumed_count.* += block_size;
+            return block_size;
+        },
+        .rle => {
+            if (src.len < 1) return error.MalformedRleBlock;
+            var write_pos: usize = 0;
+            while (write_pos < block_size) : (write_pos += 1) {
+                dest.writeAssumeCapacity(src[0]);
+            }
+            consumed_count.* += 1;
+            return block_size;
+        },
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch
+            const literals = decodeLiteralsSectionSlice(src, &bytes_read) catch
                 return error.MalformedCompressedBlock;
-            const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
+            var fbs = std.io.fixedBufferStream(src[bytes_read..]);
+            const fbs_reader = fbs.reader();
+            const sequences_header = decodeSequencesHeader(fbs_reader) catch
                 return error.MalformedCompressedBlock;
 
-            bytes_read += decode_state.prepare(src[bytes_read..], literals, sequences_header) catch
+            decode_state.prepare(fbs_reader, literals, sequences_header) catch
                 return error.MalformedCompressedBlock;
 
+            bytes_read += fbs.pos;
+
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
@@ -991,6 +948,82 @@ pub fn decodeBlockRingBuffer(
     }
 }
 
+/// Decode a single block from `source` into `dest`. Literal and sequence data
+/// from the block is copied into `literals_buffer` and `sequence_buffer`, which
+/// must be large enough or `error.LiteralsBufferTooSmall` and
+/// `error.SequenceBufferTooSmall` are returned (the maximum block size is an
+/// upper bound for the size of both buffers). See `decodeBlock`
+/// and `decodeBlockRingBuffer` for functions that can decode a block without
+/// these extra copies.
+pub fn decodeBlockReader(
+    dest: *RingBuffer,
+    source: anytype,
+    block_header: frame.ZStandard.Block.Header,
+    decode_state: *DecodeState,
+    block_size_max: usize,
+    literals_buffer: []u8,
+    sequence_buffer: []u8,
+) !void {
+    const block_size = block_header.block_size;
+    var block_reader_limited = std.io.limitedReader(source, block_size);
+    const block_reader = block_reader_limited.reader();
+    if (block_size_max < block_size) return error.BlockSizeOverMaximum;
+    switch (block_header.block_type) {
+        .raw => {
+            const slice = dest.sliceAt(dest.write_index, block_size);
+            try source.readNoEof(slice.first);
+            try source.readNoEof(slice.second);
+            dest.write_index = dest.mask2(dest.write_index + block_size);
+        },
+        .rle => {
+            const byte = try source.readByte();
+            var i: usize = 0;
+            while (i < block_size) : (i += 1) {
+                dest.writeAssumeCapacity(byte);
+            }
+        },
+        .compressed => {
+            const literals = try decodeLiteralsSection(block_reader, literals_buffer);
+            const sequences_header = try decodeSequencesHeader(block_reader);
+
+            try decode_state.prepare(block_reader, literals, sequences_header);
+
+            if (sequences_header.sequence_count > 0) {
+                if (sequence_buffer.len < block_reader_limited.bytes_left)
+                    return error.SequenceBufferTooSmall;
+
+                const size = try block_reader.readAll(sequence_buffer);
+                var bit_stream: ReverseBitReader = undefined;
+                try bit_stream.init(sequence_buffer[0..size]);
+
+                decode_state.readInitialFseState(&bit_stream) catch return error.MalformedCompressedBlock;
+
+                var sequence_size_limit = block_size_max;
+                var i: usize = 0;
+                while (i < sequences_header.sequence_count) : (i += 1) {
+                    const decompressed_size = decode_state.decodeSequenceRingBuffer(
+                        dest,
+                        &bit_stream,
+                        sequence_size_limit,
+                        i == sequences_header.sequence_count - 1,
+                    ) catch return error.MalformedCompressedBlock;
+                    sequence_size_limit -= decompressed_size;
+                }
+            }
+
+            if (decode_state.literal_written_count < literals.header.regenerated_size) {
+                const len = literals.header.regenerated_size - decode_state.literal_written_count;
+                decode_state.decodeLiteralsRingBuffer(dest, len) catch
+                    return error.MalformedCompressedBlock;
+            }
+
+            decode_state.literal_written_count = 0;
+            assert(block_reader.readByte() == error.EndOfStream);
+        },
+        .reserved => return error.ReservedBlock,
+    }
+}
+
 /// Decode the header of a skippable frame.
 pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     const magic = readInt(u32, src[0..4]);
@@ -1002,13 +1035,6 @@ pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     };
 }
 
-/// Returns the content size of a skippable frame.
-pub fn skippableFrameSize(src: *const [8]u8) usize {
-    assert(isSkippableMagic(readInt(u32, src[0..4])));
-    const frame_size = readInt(u32, src[4..8]);
-    return frame_size;
-}
-
 /// Returns the window size required to decompress a frame, or `null` if it cannot be
 /// determined, which indicates a malformed frame header.
 pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
@@ -1023,40 +1049,37 @@ pub fn frameWindowSize(header: frame.ZStandard.Header) ?u64 {
 }
 
 const InvalidBit = error{ UnusedBitSet, ReservedBitSet };
-/// Decode the header of a Zstandard frame. Returns `error.UnusedBitSet` or
-/// `error.ReservedBitSet` if the corresponding bits are sets.
-pub fn decodeZStandardHeader(src: []const u8, consumed_count: ?*usize) InvalidBit!frame.ZStandard.Header {
-    const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, src[0]);
+/// Decode the header of a Zstandard frame.
+///
+/// Errors:
+///   - returns `error.UnusedBitSet` if the unused bits of the header are set
+///   - returns `error.ReservedBitSet` if the reserved bits of the header are
+///     set
+pub fn decodeZStandardHeader(source: anytype) (error{EndOfStream} || InvalidBit)!frame.ZStandard.Header {
+    const descriptor = @bitCast(frame.ZStandard.Header.Descriptor, try source.readByte());
 
     if (descriptor.unused) return error.UnusedBitSet;
     if (descriptor.reserved) return error.ReservedBitSet;
 
-    var bytes_read_count: usize = 1;
-
     var window_descriptor: ?u8 = null;
     if (!descriptor.single_segment_flag) {
-        window_descriptor = src[bytes_read_count];
-        bytes_read_count += 1;
+        window_descriptor = try source.readByte();
     }
 
     var dictionary_id: ?u32 = null;
     if (descriptor.dictionary_id_flag > 0) {
         // if flag is 3 then field_size = 4, else field_size = flag
         const field_size = (@as(u4, 1) << descriptor.dictionary_id_flag) >> 1;
-        dictionary_id = readVarInt(u32, src[bytes_read_count .. bytes_read_count + field_size]);
-        bytes_read_count += field_size;
+        dictionary_id = try source.readVarInt(u32, .Little, field_size);
     }
 
     var content_size: ?u64 = null;
     if (descriptor.single_segment_flag or descriptor.content_size_flag > 0) {
         const field_size = @as(u4, 1) << descriptor.content_size_flag;
-        content_size = readVarInt(u64, src[bytes_read_count .. bytes_read_count + field_size]);
+        content_size = try source.readVarInt(u64, .Little, field_size);
         if (field_size == 2) content_size.? += 256;
-        bytes_read_count += field_size;
     }
 
-    if (consumed_count) |p| p.* += bytes_read_count;
-
     const header = frame.ZStandard.Header{
         .descriptor = descriptor,
         .window_descriptor = window_descriptor,
@@ -1080,12 +1103,20 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
 
 /// Decode a `LiteralsSection` from `src`, incrementing `consumed_count` by the
 /// number of bytes the section uses.
-pub fn decodeLiteralsSection(
+///
+/// Errors:
+///   - returns `error.MalformedLiteralsHeader` if the header is invalid
+///   - returns `error.MalformedLiteralsSection` if there are errors decoding the section
+pub fn decodeLiteralsSectionSlice(
     src: []const u8,
     consumed_count: *usize,
-) (error{ MalformedLiteralsHeader, MalformedLiteralsSection } || DecodeHuffmanError)!LiteralsSection {
+) (error{ MalformedLiteralsHeader, MalformedLiteralsSection, EndOfStream } || DecodeHuffmanError)!LiteralsSection {
     var bytes_read: usize = 0;
-    const header = try decodeLiteralsHeader(src, &bytes_read);
+    const header = header: {
+        var fbs = std.io.fixedBufferStream(src);
+        defer bytes_read = fbs.pos;
+        break :header decodeLiteralsHeader(fbs.reader()) catch return error.MalformedLiteralsHeader;
+    };
     switch (header.block_type) {
         .raw => {
             if (src.len < bytes_read + header.regenerated_size) return error.MalformedLiteralsSection;
@@ -1110,7 +1141,7 @@ pub fn decodeLiteralsSection(
         .compressed, .treeless => {
             const huffman_tree_start = bytes_read;
             const huffman_tree = if (header.block_type == .compressed)
-                try decodeHuffmanTree(src[bytes_read..], &bytes_read)
+                try decodeHuffmanTreeSlice(src[bytes_read..], &bytes_read)
             else
                 null;
             const huffman_tree_size = bytes_read - huffman_tree_start;
@@ -1119,137 +1150,185 @@ pub fn decodeLiteralsSection(
             if (src.len < bytes_read + total_streams_size) return error.MalformedLiteralsSection;
             const stream_data = src[bytes_read .. bytes_read + total_streams_size];
 
-            if (header.size_format == 0) {
-                consumed_count.* += total_streams_size + bytes_read;
-                return LiteralsSection{
-                    .header = header,
-                    .huffman_tree = huffman_tree,
-                    .streams = .{ .one = stream_data },
-                };
-            }
-
-            if (stream_data.len < 6) return error.MalformedLiteralsSection;
-
-            const stream_1_length = @as(usize, readInt(u16, stream_data[0..2]));
-            const stream_2_length = @as(usize, readInt(u16, stream_data[2..4]));
-            const stream_3_length = @as(usize, readInt(u16, stream_data[4..6]));
-            const stream_4_length = (total_streams_size - 6) - (stream_1_length + stream_2_length + stream_3_length);
+            const streams = try decodeStreams(header.size_format, stream_data);
+            consumed_count.* += bytes_read + total_streams_size;
+            return LiteralsSection{
+                .header = header,
+                .huffman_tree = huffman_tree,
+                .streams = streams,
+            };
+        },
+    }
+}
 
-            const stream_1_start = 6;
-            const stream_2_start = stream_1_start + stream_1_length;
-            const stream_3_start = stream_2_start + stream_2_length;
-            const stream_4_start = stream_3_start + stream_3_length;
/// Decode a `LiteralsSection` from `source`, storing literal/stream bytes in
/// `buffer`. The returned `streams` slices point into `buffer`, so `buffer`
/// must outlive the returned section.
///
/// Errors:
///   - returns `error.MalformedLiteralsHeader` if the header is invalid
///   - returns `error.MalformedLiteralsSection` if there are errors decoding
///   - returns `error.LiteralsBufferTooSmall` if `buffer` cannot hold the
///     section's data
pub fn decodeLiteralsSection(
    source: anytype,
    buffer: []u8,
) !LiteralsSection {
    const header = try decodeLiteralsHeader(source);
    switch (header.block_type) {
        .raw => {
            // Guard before slicing `buffer` so a malformed or oversized
            // section yields an error instead of an out-of-bounds panic.
            if (header.regenerated_size > buffer.len) return error.LiteralsBufferTooSmall;
            try source.readNoEof(buffer[0..header.regenerated_size]);
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                // Slice to the bytes actually read rather than the whole
                // buffer, consistent with the `.rle` case below.
                .streams = .{ .one = buffer[0..header.regenerated_size] },
            };
        },
        .rle => {
            if (buffer.len == 0) return error.LiteralsBufferTooSmall;
            // RLE literals store a single byte repeated `regenerated_size`
            // times; only the byte itself is kept here.
            buffer[0] = try source.readByte();
            return LiteralsSection{
                .header = header,
                .huffman_tree = null,
                .streams = .{ .one = buffer[0..1] },
            };
        },
        .compressed, .treeless => {
            // Count the bytes consumed by the Huffman tree description so the
            // stream size can be derived from the declared compressed size.
            var counting_reader = std.io.countingReader(source);
            const huffman_tree = if (header.block_type == .compressed)
                try decodeHuffmanTree(counting_reader.reader(), buffer)
            else
                null;
            const huffman_tree_size = counting_reader.bytes_read;
            const compressed_size = @as(usize, header.compressed_size.?);
            // A tree description larger than the declared compressed size
            // would make the subtraction below underflow (panic in safe
            // modes); report it as a malformed section instead.
            if (huffman_tree_size > compressed_size) return error.MalformedLiteralsSection;
            const total_streams_size = compressed_size - @intCast(usize, huffman_tree_size);

            if (total_streams_size > buffer.len) return error.LiteralsBufferTooSmall;
            try source.readNoEof(buffer[0..total_streams_size]);
            const stream_data = buffer[0..total_streams_size];

            const streams = try decodeStreams(header.size_format, stream_data);
            return LiteralsSection{
                .header = header,
                .huffman_tree = huffman_tree,
                .streams = streams,
            };
        },
    }
}
 
/// Split `stream_data` into the literal stream(s) described by `size_format`:
/// a single stream when `size_format == 0`, otherwise four streams whose
/// first three lengths are stored little-endian in the first 6 bytes.
///
/// Errors:
///   - returns `error.MalformedLiteralsSection` if the stored stream lengths
///     do not fit within `stream_data`
fn decodeStreams(size_format: u2, stream_data: []const u8) !LiteralsSection.Streams {
    if (size_format == 0) {
        return .{ .one = stream_data };
    }

    // Four-stream layout: a 6-byte jump table precedes the stream data.
    if (stream_data.len < 6) return error.MalformedLiteralsSection;

    const stream_1_length = @as(usize, readInt(u16, stream_data[0..2]));
    const stream_2_length = @as(usize, readInt(u16, stream_data[2..4]));
    const stream_3_length = @as(usize, readInt(u16, stream_data[4..6]));

    const stream_1_start = 6;
    const stream_2_start = stream_1_start + stream_1_length;
    const stream_3_start = stream_2_start + stream_2_length;
    const stream_4_start = stream_3_start + stream_3_length;

    // The three sized streams must lie entirely within `stream_data`;
    // without this check malformed lengths cause the slices below to go out
    // of bounds and panic instead of returning an error.
    if (stream_data.len < stream_4_start) return error.MalformedLiteralsSection;

    return .{ .four = .{
        stream_data[stream_1_start..stream_2_start],
        stream_data[stream_2_start..stream_3_start],
        stream_data[stream_3_start..stream_4_start],
        stream_data[stream_4_start..],
    } };
}
+
/// Errors that can occur while decoding the Huffman tree description of a
/// compressed literals section.
const DecodeHuffmanError = error{
    MalformedHuffmanTree,
    MalformedFseTable,
    MalformedAccuracyLog,
};
 
-fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError!LiteralsSection.HuffmanTree {
-    var bytes_read: usize = 0;
-    bytes_read += 1;
-    if (src.len == 0) return error.MalformedHuffmanTree;
-    const header = src[0];
-    var symbol_count: usize = undefined;
-    var weights: [256]u4 = undefined;
-    var max_number_of_bits: u4 = undefined;
-    if (header < 128) {
-        // FSE compressed weights
-        const compressed_size = header;
-        if (src.len < 1 + compressed_size) return error.MalformedHuffmanTree;
-        var stream = std.io.fixedBufferStream(src[1 .. compressed_size + 1]);
-        var counting_reader = std.io.countingReader(stream.reader());
-        var bit_reader = bitReader(counting_reader.reader());
-
-        var entries: [1 << 6]Table.Fse = undefined;
-        const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
-            error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
-            error.EndOfStream => return error.MalformedFseTable,
-        };
-        const accuracy_log = std.math.log2_int_ceil(usize, table_size);
-
-        const start_index = std.math.cast(usize, 1 + counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
-        var huff_data = src[start_index .. compressed_size + 1];
-        var huff_bits: ReverseBitReader = undefined;
-        huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;
-
-        var i: usize = 0;
-        var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
-        var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
-
-        while (i < 255) {
-            const even_data = entries[even_state];
-            var read_bits: usize = 0;
-            const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
-            weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
-            i += 1;
-            if (read_bits < even_data.bits) {
-                weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
-                i += 1;
-                break;
-            }
-            even_state = even_data.baseline + even_bits;
/// Decode an FSE-compressed Huffman weight description from `source`,
/// consuming exactly up to `compressed_size` bytes. The trailing weight
/// bitstream is staged in `buffer`; decoded weights are written to `weights`.
/// Returns the symbol count reported by `assignWeights`.
fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize {
    // Only `compressed_size` bytes belong to the weight description; the
    // limited reader prevents reading past it.
    var stream = std.io.limitedReader(source, compressed_size);
    var bit_reader = bitReader(stream.reader());

    var entries: [1 << 6]Table.Fse = undefined;
    const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        // A truncated table header is reported as a malformed table.
        error.EndOfStream => return error.MalformedFseTable,
    };
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    // The rest of the limited stream is the weight bitstream; it is staged
    // in `buffer` because it cannot be bit-read straight off the reader.
    // NOTE(review): `ReverseBitReader` presumably consumes the data from the
    // end — confirm against its definition. Also note: if `buffer` is shorter
    // than the remaining stream, `readAll` silently truncates — verify
    // callers always pass a large enough buffer.
    const amount = try stream.reader().readAll(buffer);
    var huff_bits: ReverseBitReader = undefined;
    huff_bits.init(buffer[0..amount]) catch return error.MalformedHuffmanTree;

    return assignWeights(&huff_bits, accuracy_log, &entries, weights);
}
+
/// Decode an FSE-compressed Huffman weight description from the first
/// `compressed_size` bytes of `src`, writing decoded weights to `weights`.
/// Returns the symbol count reported by `assignWeights`.
fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *[256]u4) !usize {
    if (src.len < compressed_size) return error.MalformedHuffmanTree;
    var stream = std.io.fixedBufferStream(src[0..compressed_size]);
    // Track how many bytes the FSE table header consumes so the weight
    // bitstream can be located immediately after it.
    var counting_reader = std.io.countingReader(stream.reader());
    var bit_reader = bitReader(counting_reader.reader());

    var entries: [1 << 6]Table.Fse = undefined;
    const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
        error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
        // A truncated table header is reported as a malformed table.
        error.EndOfStream => return error.MalformedFseTable,
    };
    const accuracy_log = std.math.log2_int_ceil(usize, table_size);

    const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
    // The weight bitstream occupies the remainder of the compressed region.
    // NOTE(review): `ReverseBitReader` presumably reads it from the end —
    // confirm against its definition.
    var huff_data = src[start_index..compressed_size];
    var huff_bits: ReverseBitReader = undefined;
    huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;

    return assignWeights(&huff_bits, accuracy_log, &entries, weights);
}
+
/// Decode Huffman symbol weights from `huff_bits` using the FSE table in
/// `entries`. Two interleaved FSE states are used: `even_state` produces
/// even-indexed weights and `odd_state` odd-indexed ones. Returns the total
/// symbol count, which includes one final symbol whose weight is not stored
/// in the stream (it is derived later from the weight sum).
fn assignWeights(huff_bits: *ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize {
    var i: usize = 0;
    // Both states are initialized with `accuracy_log` bits each.
    var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
    var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;

    while (i < 255) {
        const even_data = entries[even_state];
        var read_bits: usize = 0;
        const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
        // Weights are 4-bit; a larger symbol means the table is corrupt.
        weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        if (read_bits < even_data.bits) {
            // Bitstream exhausted mid-update: the last stored weight comes
            // from the other (odd) state, then decoding stops.
            weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        even_state = even_data.baseline + even_bits;

        read_bits = 0;
        const odd_data = entries[odd_state];
        const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
        weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
        i += 1;
        if (read_bits < odd_data.bits) {
            // Same exhaustion case for the odd state; guard against writing
            // a 257th weight.
            if (i == 256) return error.MalformedHuffmanTree;
            weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            break;
        }
        odd_state = odd_data.baseline + odd_bits;
    } else return error.MalformedHuffmanTree;

    return i + 1; // stream contains all but the last symbol
}
 
-    var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
-    for (weight_sorted_prefixed_symbols[0..symbol_count]) |_, i| {
/// Read directly-encoded Huffman weights from `source`: each byte packs two
/// 4-bit weights (high nibble first). Fills `weights` and returns the total
/// symbol count (`encoded_symbol_count + 1`, as the last weight is implicit).
fn decodeDirectHuffmanTree(source: anytype, encoded_symbol_count: usize, weights: *[256]u4) !usize {
    // Two weights per byte, rounded up.
    const packed_byte_count = (encoded_symbol_count + 1) / 2;
    var byte_index: usize = 0;
    while (byte_index < packed_byte_count) : (byte_index += 1) {
        const packed_pair = try source.readByte();
        weights[2 * byte_index] = @intCast(u4, packed_pair >> 4);
        weights[2 * byte_index + 1] = @intCast(u4, packed_pair & 0xF);
    }
    return encoded_symbol_count + 1;
}
+
+fn assignSymbols(weight_sorted_prefixed_symbols: []LiteralsSection.HuffmanTree.PrefixedSymbol, weights: [256]u4) usize {
+    for (weight_sorted_prefixed_symbols) |_, i| {
         weight_sorted_prefixed_symbols[i] = .{
             .symbol = @intCast(u8, i),
             .weight = undefined,
@@ -1259,7 +1338,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError
 
     std.sort.sort(
         LiteralsSection.HuffmanTree.PrefixedSymbol,
-        weight_sorted_prefixed_symbols[0..symbol_count],
+        weight_sorted_prefixed_symbols,
         weights,
         lessThanByWeight,
     );
@@ -1267,6 +1346,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError
     var prefix: u16 = 0;
     var prefixed_symbol_count: usize = 0;
     var sorted_index: usize = 0;
+    const symbol_count = weight_sorted_prefixed_symbols.len;
     while (sorted_index < symbol_count) {
         var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
         const weight = weights[symbol];
@@ -1290,7 +1370,24 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError
             weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight;
         }
     }
-    consumed_count.* += bytes_read;
+    return prefixed_symbol_count;
+}
+
+fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) LiteralsSection.HuffmanTree {
+    var weight_power_sum: u16 = 0;
+    for (weights[0 .. symbol_count - 1]) |value| {
+        if (value > 0) {
+            weight_power_sum += @as(u16, 1) << (value - 1);
+        }
+    }
+
+    // advance to next power of two (even if weight_power_sum is a power of 2)
+    const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
+    const next_power_of_two = @as(u16, 1) << max_number_of_bits;
+    weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1;
+
+    var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
+    const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*);
     const tree = LiteralsSection.HuffmanTree{
         .max_bit_count = max_number_of_bits,
         .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1),
@@ -1299,6 +1396,37 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) DecodeHuffmanError
     return tree;
 }
 
/// Decode a Huffman tree description from `source`. The leading header byte
/// selects the encoding: values below 128 give the size of an
/// FSE-compressed weight stream (staged in `buffer`), values of 128 or more
/// encode `header - 127` direct 4-bit weights.
fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree {
    const header_byte = try source.readByte();
    var weights: [256]u4 = undefined;
    const symbol_count = if (header_byte >= 128)
        try decodeDirectHuffmanTree(source, header_byte - 127, &weights)
    else
        // FSE-compressed weights; `header_byte` is their compressed size.
        try decodeFseHuffmanTree(source, header_byte, buffer, &weights);

    return buildHuffmanTree(&weights, symbol_count);
}
+
/// Decode a Huffman tree description from the start of `src`, adding the
/// number of bytes consumed to `consumed_count` on success. The leading
/// header byte selects FSE-compressed (< 128) or direct (>= 128) weights.
fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) (error{EndOfStream} || DecodeHuffmanError)!LiteralsSection.HuffmanTree {
    if (src.len == 0) return error.MalformedHuffmanTree;
    const header_byte = src[0];
    var weights: [256]u4 = undefined;
    var used: usize = 1; // the header byte itself
    var symbol_count: usize = undefined;
    if (header_byte < 128) {
        // FSE-compressed weights; `header_byte` is their compressed size.
        symbol_count = try decodeFseHuffmanTreeSlice(src[1..], header_byte, &weights);
        used += header_byte;
    } else {
        var weight_stream = std.io.fixedBufferStream(src[1..]);
        symbol_count = try decodeDirectHuffmanTree(weight_stream.reader(), header_byte - 127, &weights);
        used += weight_stream.pos;
    }

    // `consumed_count` is only updated on success, matching the original
    // control flow where errors propagate before this point.
    consumed_count.* += used;
    return buildHuffmanTree(&weights, symbol_count);
}
+
 fn lessThanByWeight(
     weights: [256]u4,
     lhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
@@ -1311,9 +1439,8 @@ fn lessThanByWeight(
 }
 
 /// Decode a literals section header.
-pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{MalformedLiteralsHeader}!LiteralsSection.Header {
-    if (src.len == 0) return error.MalformedLiteralsHeader;
-    const byte0 = src[0];
+pub fn decodeLiteralsHeader(source: anytype) !LiteralsSection.Header {
+    const byte0 = try source.readByte();
     const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
     const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
     var regenerated_size: u20 = undefined;
@@ -1323,47 +1450,31 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{Malfo
             switch (size_format) {
                 0, 2 => {
                     regenerated_size = byte0 >> 3;
-                    consumed_count.* += 1;
-                },
-                1 => {
-                    if (src.len < 2) return error.MalformedLiteralsHeader;
-                    regenerated_size = (byte0 >> 4) +
-                        (@as(u20, src[1]) << 4);
-                    consumed_count.* += 2;
-                },
-                3 => {
-                    if (src.len < 3) return error.MalformedLiteralsHeader;
-                    regenerated_size = (byte0 >> 4) +
-                        (@as(u20, src[1]) << 4) +
-                        (@as(u20, src[2]) << 12);
-                    consumed_count.* += 3;
                 },
+                1 => regenerated_size = (byte0 >> 4) + (@as(u20, try source.readByte()) << 4),
+                3 => regenerated_size = (byte0 >> 4) +
+                    (@as(u20, try source.readByte()) << 4) +
+                    (@as(u20, try source.readByte()) << 12),
             }
         },
         .compressed, .treeless => {
-            const byte1 = src[1];
-            const byte2 = src[2];
+            const byte1 = try source.readByte();
+            const byte2 = try source.readByte();
             switch (size_format) {
                 0, 1 => {
-                    if (src.len < 3) return error.MalformedLiteralsHeader;
                     regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4);
                     compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2);
-                    consumed_count.* += 3;
                 },
                 2 => {
-                    if (src.len < 4) return error.MalformedLiteralsHeader;
-                    const byte3 = src[3];
+                    const byte3 = try source.readByte();
                     regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12);
                     compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6);
-                    consumed_count.* += 4;
                 },
                 3 => {
-                    if (src.len < 5) return error.MalformedLiteralsHeader;
-                    const byte3 = src[3];
-                    const byte4 = src[4];
+                    const byte3 = try source.readByte();
+                    const byte4 = try source.readByte();
                     regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12);
                     compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10);
-                    consumed_count.* += 5;
                 },
             }
         },
@@ -1377,18 +1488,17 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) error{Malfo
 }
 
 /// Decode a sequences section header.
+///
+/// Errors:
+///   - returns `error.ReservedBitSet` is the reserved bit is set
+///   - returns `error.MalformedSequencesHeader` if the header is invalid
 pub fn decodeSequencesHeader(
-    src: []const u8,
-    consumed_count: *usize,
-) error{ MalformedSequencesHeader, ReservedBitSet }!SequencesSection.Header {
-    if (src.len == 0) return error.MalformedSequencesHeader;
+    source: anytype,
+) !SequencesSection.Header {
     var sequence_count: u24 = undefined;
 
-    var bytes_read: usize = 0;
-    const byte0 = src[0];
+    const byte0 = try source.readByte();
     if (byte0 == 0) {
-        bytes_read += 1;
-        consumed_count.* += bytes_read;
         return SequencesSection.Header{
             .sequence_count = 0,
             .offsets = undefined,
@@ -1397,22 +1507,14 @@ pub fn decodeSequencesHeader(
         };
     } else if (byte0 < 128) {
         sequence_count = byte0;
-        bytes_read += 1;
     } else if (byte0 < 255) {
-        if (src.len < 2) return error.MalformedSequencesHeader;
-        sequence_count = (@as(u24, (byte0 - 128)) << 8) + src[1];
-        bytes_read += 2;
+        sequence_count = (@as(u24, (byte0 - 128)) << 8) + try source.readByte();
     } else {
-        if (src.len < 3) return error.MalformedSequencesHeader;
-        sequence_count = src[1] + (@as(u24, src[2]) << 8) + 0x7F00;
-        bytes_read += 3;
+        sequence_count = (try source.readByte()) + (@as(u24, try source.readByte()) << 8) + 0x7F00;
     }
 
-    if (src.len < bytes_read + 1) return error.MalformedSequencesHeader;
-    const compression_modes = src[bytes_read];
-    bytes_read += 1;
+    const compression_modes = try source.readByte();
 
-    consumed_count.* += bytes_read;
     const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2);
     const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4);
     const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6);
@@ -1615,7 +1717,7 @@ fn BitReader(comptime Reader: type) type {
     };
 }
 
/// Construct a little-endian `BitReader` wrapping `reader`.
pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
    const underlying = std.io.bitReader(.Little, reader);
    return .{ .underlying = underlying };
}