Commit 05e63f241e

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-22 06:11:47
std.compress.zstandard: add functions decoding into ring buffer
This supports decoding frames that do not declare the content size or decoding in a streaming fashion.
1 parent 1809172
Changed files (2)
lib
std
compress
lib/std/compress/zstandard/decompress.zig
@@ -6,6 +6,7 @@ const frame = types.frame;
 const Literals = types.compressed_block.Literals;
 const Sequences = types.compressed_block.Sequences;
 const Table = types.compressed_block.Table;
+const RingBuffer = @import("RingBuffer.zig");
 
 const readInt = std.mem.readIntLittle;
 const readIntSlice = std.mem.readIntSliceLittle;
@@ -214,7 +215,7 @@ const DecodeState = struct {
     }
 
     fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
-        try self.decodeLiteralsInto(dest[write_pos..], literals, sequence.literal_length);
+        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
 
         // TODO: should we validate offset against max_window_size?
         assert(sequence.offset <= write_pos + sequence.literal_length);
@@ -225,6 +226,15 @@ const DecodeState = struct {
         std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
     }
 
+    fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
+        try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
+        // TODO: check that ring buffer window is full enough for match copies
+        const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
+        // TODO: would std.mem.copy and figuring out dest slice be better/faster?
+        for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
+        for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
+    }
+
     fn decodeSequenceSlice(
         self: *DecodeState,
         dest: []u8,
@@ -246,6 +256,31 @@ const DecodeState = struct {
         return sequence.match_length + sequence.literal_length;
     }
 
+    fn decodeSequenceRingBuffer(
+        self: *DecodeState,
+        dest: *RingBuffer,
+        literals: Literals,
+        bit_reader: anytype,
+        last_sequence: bool,
+    ) !usize {
+        const sequence = try self.nextSequence(bit_reader);
+        try self.executeSequenceRingBuffer(dest, literals, sequence);
+        if (std.options.log_level == .debug) {
+            const sequence_length = sequence.literal_length + sequence.match_length;
+            const written_slice = dest.sliceLast(sequence_length);
+            log.debug("sequence decompressed into '{x}{x}'", .{
+                std.fmt.fmtSliceHexUpper(written_slice.first),
+                std.fmt.fmtSliceHexUpper(written_slice.second),
+            });
+        }
+        if (!last_sequence) {
+            try self.updateState(.literal, bit_reader);
+            try self.updateState(.match, bit_reader);
+            try self.updateState(.offset, bit_reader);
+        }
+        return sequence.match_length + sequence.literal_length;
+    }
+
     fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
         self.literal_stream_index += 1;
         try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
@@ -258,7 +293,7 @@ const DecodeState = struct {
         while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {}
     }
 
-    fn decodeLiteralsInto(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
+    fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
             .raw => {
@@ -327,6 +362,74 @@ const DecodeState = struct {
         }
     }
 
+    fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
+        if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+        switch (literals.header.block_type) {
+            .raw => {
+                const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+                dest.writeSliceAssumeCapacity(literal_data);
+                self.literal_written_count += len;
+            },
+            .rle => {
+                var i: usize = 0;
+                while (i < len) : (i += 1) {
+                    dest.writeAssumeCapacity(literals.streams.one[0]);
+                }
+                self.literal_written_count += len;
+            },
+            .compressed, .treeless => {
+                // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
+                const huffman_tree = self.huffman_tree orelse unreachable;
+                const max_bit_count = huffman_tree.max_bit_count;
+                const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
+                    huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
+                    max_bit_count,
+                );
+                var bits_read: u4 = 0;
+                var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
+                var bit_count_to_read: u4 = starting_bit_count;
+                var i: usize = 0;
+                while (i < len) : (i += 1) {
+                    var prefix: u16 = 0;
+                    while (true) {
+                        const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
+                            switch (err) {
+                            error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
+                                try self.nextLiteralMultiStream(literals);
+                                break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
+                            } else {
+                                return error.UnexpectedEndOfLiteralStream;
+                            },
+                        };
+                        prefix <<= bit_count_to_read;
+                        prefix |= new_bits;
+                        bits_read += bit_count_to_read;
+                        const result = try huffman_tree.query(huffman_tree_index, prefix);
+
+                        switch (result) {
+                            .symbol => |sym| {
+                                dest.writeAssumeCapacity(sym);
+                                bit_count_to_read = starting_bit_count;
+                                bits_read = 0;
+                                huffman_tree_index = huffman_tree.symbol_count_minus_one;
+                                break;
+                            },
+                            .index => |index| {
+                                huffman_tree_index = index;
+                                const bit_count = Literals.HuffmanTree.weightToBitCount(
+                                    huffman_tree.nodes[index].weight,
+                                    max_bit_count,
+                                );
+                                bit_count_to_read = bit_count - bits_read;
+                            },
+                        }
+                    }
+                }
+                self.literal_written_count += len;
+            },
+        }
+    }
+
     fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
         return switch (@field(self, @tagName(choice)).table) {
             .rle => |value| value,
@@ -437,6 +540,14 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
+fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing raw block - size {d}", .{block_size});
+    const data = src[0..block_size];
+    dest.writeSliceAssumeCapacity(data);
+    consumed_count.* += block_size;
+    return block_size;
+}
+
 fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
     log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
     var write_pos: usize = 0;
@@ -447,6 +558,16 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
     return block_size;
 }
 
+fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
+    var write_pos: usize = 0;
+    while (write_pos < block_size) : (write_pos += 1) {
+        dest.writeAssumeCapacity(src[0]);
+    }
+    consumed_count.* += 1;
+    return block_size;
+}
+
 fn prepareDecodeState(
     decode_state: *DecodeState,
     src: []const u8,
@@ -545,7 +666,7 @@ pub fn decodeBlock(
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 log.debug("decoding remaining literals", .{});
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                try decode_state.decodeLiteralsInto(dest[written_count + bytes_written ..], literals, len);
+                try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
                 log.debug("remaining decoded literals at {d}: {}", .{
                     written_count,
                     std.fmt.fmtSliceHexUpper(dest[written_count .. written_count + len]),
@@ -562,6 +683,73 @@ pub fn decodeBlock(
     }
 }
 
+pub fn decodeBlockRingBuffer(
+    dest: *RingBuffer,
+    src: []const u8,
+    block_header: frame.ZStandard.Block.Header,
+    decode_state: *DecodeState,
+    consumed_count: *usize,
+    block_size_maximum: usize,
+) !usize {
+    const block_size = block_header.block_size;
+    if (block_size_maximum < block_size) return error.BlockSizeOverMaximum;
+    // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
+    switch (block_header.block_type) {
+        .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
+        .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
+        .compressed => {
+            var bytes_read: usize = 0;
+            const literals = try decodeLiteralsSection(src, &bytes_read);
+            const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+
+            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
+
+            var bytes_written: usize = 0;
+            if (sequences_header.sequence_count > 0) {
+                const bit_stream_bytes = src[bytes_read..block_size];
+                var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
+                var bit_stream = reverseBitReader(reverse_byte_reader.reader());
+
+                while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
+                try decode_state.readInitialState(&bit_stream);
+
+                var i: usize = 0;
+                while (i < sequences_header.sequence_count) : (i += 1) {
+                    log.debug("decoding sequence {d}", .{i});
+                    const decompressed_size = try decode_state.decodeSequenceRingBuffer(
+                        dest,
+                        literals,
+                        &bit_stream,
+                        i == sequences_header.sequence_count - 1,
+                    );
+                    bytes_written += decompressed_size;
+                }
+
+                bytes_read += bit_stream_bytes.len;
+            }
+
+            if (decode_state.literal_written_count < literals.header.regenerated_size) {
+                log.debug("decoding remaining literals", .{});
+                const len = literals.header.regenerated_size - decode_state.literal_written_count;
+                try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
+                const written_slice = dest.sliceLast(len);
+                log.debug("remaining decoded literals at {d}: {}{}", .{
+                    bytes_written,
+                    std.fmt.fmtSliceHexUpper(written_slice.first),
+                    std.fmt.fmtSliceHexUpper(written_slice.second),
+                });
+                bytes_written += len;
+            }
+
+            decode_state.literal_written_count = 0;
+            assert(bytes_read == block_header.block_size);
+            consumed_count.* += bytes_read;
+            return bytes_written;
+        },
+        .reserved => return error.FrameContainsReservedBlock,
+    }
+}
+
 pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
     const magic = readInt(u32, src[0..4]);
     assert(isSkippableMagic(magic));
lib/std/compress/zstandard/RingBuffer.zig
@@ -0,0 +1,81 @@
+//! This ring buffer stores read and write indices while being able to utilise the full
+//! backing slice by incrementing the indices modulo twice the slice's length and reducing
+//! indices modulo the slice's length on slice access. This means that the bit of information
+//! distinguishing whether the buffer is full or empty in an implementation utilising
+//! and extra flag is stored in difference of the indices.
+
+const assert = @import("std").debug.assert;
+
+const RingBuffer = @This();
+
+data: []u8,
+read_index: usize,
+write_index: usize,
+
+pub fn mask(self: RingBuffer, index: usize) usize {
+    return index % self.data.len;
+}
+
+pub fn mask2(self: RingBuffer, index: usize) usize {
+    return index % (2 * self.data.len);
+}
+
+pub fn write(self: *RingBuffer, byte: u8) !void {
+    if (self.isFull()) return error.Full;
+    self.writeAssumeCapacity(byte);
+}
+
+pub fn writeAssumeCapacity(self: *RingBuffer, byte: u8) void {
+    self.data[self.mask(self.write_index)] = byte;
+    self.write_index = self.mask2(self.write_index + 1);
+}
+
+pub fn writeSlice(self: *RingBuffer, bytes: []const u8) !void {
+    if (self.len() + bytes.len > self.data.len) return error.Full;
+    self.writeSliceAssumeCapacity(bytes);
+}
+
+pub fn writeSliceAssumeCapacity(self: *RingBuffer, bytes: []const u8) void {
+    for (bytes) |b| self.writeAssumeCapacity(b);
+}
+
+pub fn read(self: *RingBuffer) ?u8 {
+    if (self.isEmpty()) return null;
+    const byte = self.data[self.mask(self.read_index)];
+    self.read_index = self.mask2(self.read_index + 1);
+    return byte;
+}
+
+pub fn isEmpty(self: RingBuffer) bool {
+    return self.write_index == self.read_index;
+}
+
+pub fn isFull(self: RingBuffer) bool {
+    return self.mask2(self.write_index + self.data.len) == self.read_index;
+}
+
+pub fn len(self: RingBuffer) usize {
+    const adjusted_write_index = self.write_index + @boolToInt(self.write_index < self.read_index) * 2 * self.data.len;
+    return adjusted_write_index - self.read_index;
+}
+
+const Slice = struct {
+    first: []u8,
+    second: []u8,
+};
+
+pub fn sliceAt(self: RingBuffer, start_unmasked: usize, length: usize) Slice {
+    assert(length <= self.data.len);
+    const slice1_start = self.mask(start_unmasked);
+    const slice1_end = @min(self.data.len, slice1_start + length);
+    const slice1 = self.data[slice1_start..slice1_end];
+    const slice2 = self.data[0 .. length - slice1.len];
+    return Slice{
+        .first = slice1,
+        .second = slice2,
+    };
+}
+
+pub fn sliceLast(self: RingBuffer, length: usize) Slice {
+    return self.sliceAt(self.write_index + self.data.len - length, length);
+}