Commit 05e63f241e
Changed files (2)
lib/std/compress/zstandard/decompress.zig
@@ -6,6 +6,7 @@ const frame = types.frame;
const Literals = types.compressed_block.Literals;
const Sequences = types.compressed_block.Sequences;
const Table = types.compressed_block.Table;
+const RingBuffer = @import("RingBuffer.zig");
const readInt = std.mem.readIntLittle;
const readIntSlice = std.mem.readIntSliceLittle;
@@ -214,7 +215,7 @@ const DecodeState = struct {
}
fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
- try self.decodeLiteralsInto(dest[write_pos..], literals, sequence.literal_length);
+ try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
// TODO: should we validate offset against max_window_size?
assert(sequence.offset <= write_pos + sequence.literal_length);
@@ -225,6 +226,15 @@ const DecodeState = struct {
std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
}
+ fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
+ try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
+ // TODO: check that ring buffer window is full enough for match copies
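+ // Unmasked index arithmetic: adding dest.data.len before subtracting the offset keeps the index non-negative; sliceAt reduces it modulo the buffer length.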
+ const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
+ // TODO: would std.mem.copy and figuring out dest slice be better/faster?
+ for (copy_slice.first) |b| dest.writeAssumeCapacity(b);
+ for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
+ }
+
fn decodeSequenceSlice(
self: *DecodeState,
dest: []u8,
@@ -246,6 +256,31 @@ const DecodeState = struct {
return sequence.match_length + sequence.literal_length;
}
+ fn decodeSequenceRingBuffer(
+ self: *DecodeState,
+ dest: *RingBuffer,
+ literals: Literals,
+ bit_reader: anytype,
+ last_sequence: bool,
+ ) !usize {
+ const sequence = try self.nextSequence(bit_reader);
+ try self.executeSequenceRingBuffer(dest, literals, sequence);
+ if (std.options.log_level == .debug) {
+ const sequence_length = sequence.literal_length + sequence.match_length;
+ const written_slice = dest.sliceLast(sequence_length);
+ log.debug("sequence decompressed into '{x}{x}'", .{
+ std.fmt.fmtSliceHexUpper(written_slice.first),
+ std.fmt.fmtSliceHexUpper(written_slice.second),
+ });
+ }
+ if (!last_sequence) {
+ try self.updateState(.literal, bit_reader);
+ try self.updateState(.match, bit_reader);
+ try self.updateState(.offset, bit_reader);
+ }
+ return sequence.match_length + sequence.literal_length;
+ }
+
fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
self.literal_stream_index += 1;
try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
@@ -258,7 +293,7 @@ const DecodeState = struct {
while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {}
}
- fn decodeLiteralsInto(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
+ fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
switch (literals.header.block_type) {
.raw => {
@@ -327,6 +362,74 @@ const DecodeState = struct {
}
}
+ fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
+ if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
+ switch (literals.header.block_type) {
+ .raw => {
+ const literal_data = literals.streams.one[self.literal_written_count .. self.literal_written_count + len];
+ dest.writeSliceAssumeCapacity(literal_data);
+ self.literal_written_count += len;
+ },
+ .rle => {
+ var i: usize = 0;
+ while (i < len) : (i += 1) {
+ dest.writeAssumeCapacity(literals.streams.one[0]);
+ }
+ self.literal_written_count += len;
+ },
+ .compressed, .treeless => {
+ // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
+ const huffman_tree = self.huffman_tree orelse unreachable;
+ const max_bit_count = huffman_tree.max_bit_count;
+ const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
+ huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
+ max_bit_count,
+ );
+ var bits_read: u4 = 0;
+ var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
+ var bit_count_to_read: u4 = starting_bit_count;
+ var i: usize = 0;
+ while (i < len) : (i += 1) {
+ var prefix: u16 = 0;
+ while (true) {
+ const new_bits = self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch |err|
+ switch (err) {
+ error.EndOfStream => if (literals.streams == .four and self.literal_stream_index < 3) bits: {
+ try self.nextLiteralMultiStream(literals);
+ break :bits try self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read);
+ } else {
+ return error.UnexpectedEndOfLiteralStream;
+ },
+ };
+ prefix <<= bit_count_to_read;
+ prefix |= new_bits;
+ bits_read += bit_count_to_read;
+ const result = try huffman_tree.query(huffman_tree_index, prefix);
+
+ switch (result) {
+ .symbol => |sym| {
+ dest.writeAssumeCapacity(sym);
+ bit_count_to_read = starting_bit_count;
+ bits_read = 0;
+ huffman_tree_index = huffman_tree.symbol_count_minus_one;
+ break;
+ },
+ .index => |index| {
+ huffman_tree_index = index;
+ const bit_count = Literals.HuffmanTree.weightToBitCount(
+ huffman_tree.nodes[index].weight,
+ max_bit_count,
+ );
+ bit_count_to_read = bit_count - bits_read;
+ },
+ }
+ }
+ }
+ self.literal_written_count += len;
+ },
+ }
+ }
+
fn getCode(self: *DecodeState, comptime choice: DataType) u32 {
return switch (@field(self, @tagName(choice)).table) {
.rle => |value| value,
@@ -437,6 +540,14 @@ fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
return block_size;
}
+fn decodeRawBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+ log.debug("writing raw block - size {d}", .{block_size});
+ const data = src[0..block_size];
+ dest.writeSliceAssumeCapacity(data);
+ consumed_count.* += block_size;
+ return block_size;
+}
+
fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
var write_pos: usize = 0;
@@ -447,6 +558,16 @@ fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count:
return block_size;
}
+fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+ log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
+ var write_pos: usize = 0;
+ while (write_pos < block_size) : (write_pos += 1) {
+ dest.writeAssumeCapacity(src[0]);
+ }
+ consumed_count.* += 1;
+ return block_size;
+}
+
fn prepareDecodeState(
decode_state: *DecodeState,
src: []const u8,
@@ -545,7 +666,7 @@ pub fn decodeBlock(
if (decode_state.literal_written_count < literals.header.regenerated_size) {
log.debug("decoding remaining literals", .{});
const len = literals.header.regenerated_size - decode_state.literal_written_count;
- try decode_state.decodeLiteralsInto(dest[written_count + bytes_written ..], literals, len);
+ try decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len);
log.debug("remaining decoded literals at {d}: {}", .{
written_count,
std.fmt.fmtSliceHexUpper(dest[written_count .. written_count + len]),
@@ -562,6 +683,73 @@ pub fn decodeBlock(
}
}
+pub fn decodeBlockRingBuffer(
+ dest: *RingBuffer,
+ src: []const u8,
+ block_header: frame.ZStandard.Block.Header,
+ decode_state: *DecodeState,
+ consumed_count: *usize,
+ block_size_maximum: usize,
+) !usize {
+ const block_size = block_header.block_size;
+ if (block_size_maximum < block_size) return error.BlockSizeOverMaximum;
+ // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
+ switch (block_header.block_type) {
+ .raw => return decodeRawBlockRingBuffer(dest, src, block_size, consumed_count),
+ .rle => return decodeRleBlockRingBuffer(dest, src, block_size, consumed_count),
+ .compressed => {
+ var bytes_read: usize = 0;
+ const literals = try decodeLiteralsSection(src, &bytes_read);
+ const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
+
+ bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
+
+ var bytes_written: usize = 0;
+ if (sequences_header.sequence_count > 0) {
+ const bit_stream_bytes = src[bytes_read..block_size];
+ var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
+ var bit_stream = reverseBitReader(reverse_byte_reader.reader());
+
+ while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
+ try decode_state.readInitialState(&bit_stream);
+
+ var i: usize = 0;
+ while (i < sequences_header.sequence_count) : (i += 1) {
+ log.debug("decoding sequence {d}", .{i});
+ const decompressed_size = try decode_state.decodeSequenceRingBuffer(
+ dest,
+ literals,
+ &bit_stream,
+ i == sequences_header.sequence_count - 1,
+ );
+ bytes_written += decompressed_size;
+ }
+
+ bytes_read += bit_stream_bytes.len;
+ }
+
+ if (decode_state.literal_written_count < literals.header.regenerated_size) {
+ log.debug("decoding remaining literals", .{});
+ const len = literals.header.regenerated_size - decode_state.literal_written_count;
+ try decode_state.decodeLiteralsRingBuffer(dest, literals, len);
+ const written_slice = dest.sliceLast(len);
+ log.debug("remaining decoded literals at {d}: {}{}", .{
+ bytes_written,
+ std.fmt.fmtSliceHexUpper(written_slice.first),
+ std.fmt.fmtSliceHexUpper(written_slice.second),
+ });
+ bytes_written += len;
+ }
+
+ decode_state.literal_written_count = 0;
+ assert(bytes_read == block_header.block_size);
+ consumed_count.* += bytes_read;
+ return bytes_written;
+ },
+ .reserved => return error.FrameContainsReservedBlock,
+ }
+}
+
pub fn decodeSkippableHeader(src: *const [8]u8) frame.Skippable.Header {
const magic = readInt(u32, src[0..4]);
assert(isSkippableMagic(magic));
lib/std/compress/zstandard/RingBuffer.zig
@@ -0,0 +1,81 @@
+//! This ring buffer stores read and write indices while being able to utilise
+//! the full backing slice by incrementing the indices modulo twice the slice's
+//! length and reducing them modulo the slice's length on slice access. This
+//! means that the bit of information an implementation with an extra flag would
+//! use to distinguish a full buffer from an empty one is instead stored in the
+//! difference of the indices.
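+//!
+//! For example (illustrative values): with data.len == 8, an empty buffer has
+//! write_index == read_index, while a full one has the indices exactly 8 apart
+//! modulo 16; both cases reduce to the same slot under mask(), and only the
+//! index difference tells them apart.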
+
+const assert = @import("std").debug.assert;
+
+const RingBuffer = @This();
+
+data: []u8,
+read_index: usize,
+write_index: usize,
+
+pub fn mask(self: RingBuffer, index: usize) usize {
+ return index % self.data.len;
+}
+
+pub fn mask2(self: RingBuffer, index: usize) usize {
+ return index % (2 * self.data.len);
+}
+
+pub fn write(self: *RingBuffer, byte: u8) !void {
+ if (self.isFull()) return error.Full;
+ self.writeAssumeCapacity(byte);
+}
+
+pub fn writeAssumeCapacity(self: *RingBuffer, byte: u8) void {
+ self.data[self.mask(self.write_index)] = byte;
+ self.write_index = self.mask2(self.write_index + 1);
+}
+
+pub fn writeSlice(self: *RingBuffer, bytes: []const u8) !void {
+ if (self.len() + bytes.len > self.data.len) return error.Full;
+ self.writeSliceAssumeCapacity(bytes);
+}
+
+pub fn writeSliceAssumeCapacity(self: *RingBuffer, bytes: []const u8) void {
+ for (bytes) |b| self.writeAssumeCapacity(b);
+}
+
+pub fn read(self: *RingBuffer) ?u8 {
+ if (self.isEmpty()) return null;
+ const byte = self.data[self.mask(self.read_index)];
+ self.read_index = self.mask2(self.read_index + 1);
+ return byte;
+}
+
+pub fn isEmpty(self: RingBuffer) bool {
+ return self.write_index == self.read_index;
+}
+
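+/// Returns whether the buffer is full, i.e. the indices are data.len apart modulo 2 * data.len.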
+pub fn isFull(self: RingBuffer) bool {
+ return self.mask2(self.write_index + self.data.len) == self.read_index;
+}
+
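+/// Returns the number of bytes in the buffer: the index difference, unwrapped by 2 * data.len when the write index has wrapped around.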
+pub fn len(self: RingBuffer) usize {
+ const adjusted_write_index = self.write_index + @boolToInt(self.write_index < self.read_index) * 2 * self.data.len;
+ return adjusted_write_index - self.read_index;
+}
+
+const Slice = struct {
+ first: []u8,
+ second: []u8,
+};
+
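+/// Returns the bytes at the unmasked range start_unmasked..start_unmasked + length as two slices; the second is non-empty only when the range wraps past the end of data.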
+pub fn sliceAt(self: RingBuffer, start_unmasked: usize, length: usize) Slice {
+ assert(length <= self.data.len);
+ const slice1_start = self.mask(start_unmasked);
+ const slice1_end = @min(self.data.len, slice1_start + length);
+ const slice1 = self.data[slice1_start..slice1_end];
+ const slice2 = self.data[0 .. length - slice1.len];
+ return Slice{
+ .first = slice1,
+ .second = slice2,
+ };
+}
+
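+/// Returns the last length bytes written. Adding data.len before subtracting keeps the unmasked start index from underflowing.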
+pub fn sliceLast(self: RingBuffer, length: usize) Slice {
+ return self.sliceAt(self.write_index + self.data.len - length, length);
+}