Commit fc64c279a4

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-23 06:26:03
std.compress.zstandard: clean up api
1 parent cbfaa87
Changed files (3)
lib
lib/std/compress/zstandard/decompress.zig
@@ -3,10 +3,10 @@ const assert = std.debug.assert;
 
 const types = @import("types.zig");
 const frame = types.frame;
-const Literals = types.compressed_block.Literals;
-const Sequences = types.compressed_block.Sequences;
+const LiteralsSection = types.compressed_block.LiteralsSection;
+const SequencesSection = types.compressed_block.SequencesSection;
 const Table = types.compressed_block.Table;
-const RingBuffer = @import("RingBuffer.zig");
+pub const RingBuffer = @import("RingBuffer.zig");
 
 const readInt = std.mem.readIntLittle;
 const readIntSlice = std.mem.readIntSliceLittle;
@@ -55,7 +55,7 @@ pub fn decodeFrame(dest: []u8, src: []const u8, verify_checksum: bool) !ReadWrit
     };
 }
 
-const DecodeState = struct {
+pub const DecodeState = struct {
     repeat_offsets: [3]u32,
 
     offset: StateData(8),
@@ -70,7 +70,7 @@ const DecodeState = struct {
 
     literal_stream_reader: ReverseBitReader,
     literal_stream_index: usize,
-    huffman_tree: ?Literals.HuffmanTree,
+    huffman_tree: ?LiteralsSection.HuffmanTree,
 
     literal_written_count: usize,
 
@@ -84,7 +84,55 @@ const DecodeState = struct {
         };
     }
 
-    fn readInitialState(self: *DecodeState, bit_reader: anytype) !void {
+    pub fn prepare(
+        self: *DecodeState,
+        src: []const u8,
+        literals: LiteralsSection,
+        sequences_header: SequencesSection.Header,
+    ) !usize {
+        if (literals.huffman_tree) |tree| {
+            self.huffman_tree = tree;
+        } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
+            return error.TreelessLiteralsFirst;
+        }
+
+        switch (literals.header.block_type) {
+            .raw, .rle => {},
+            .compressed, .treeless => {
+                self.literal_stream_index = 0;
+                switch (literals.streams) {
+                    .one => |slice| try self.initLiteralStream(slice),
+                    .four => |streams| try self.initLiteralStream(streams[0]),
+                }
+            },
+        }
+
+        if (sequences_header.sequence_count > 0) {
+            var bytes_read = try self.updateFseTable(
+                src,
+                .literal,
+                sequences_header.literal_lengths,
+            );
+
+            bytes_read += try self.updateFseTable(
+                src[bytes_read..],
+                .offset,
+                sequences_header.offsets,
+            );
+
+            bytes_read += try self.updateFseTable(
+                src[bytes_read..],
+                .match,
+                sequences_header.match_lengths,
+            );
+            self.fse_tables_undefined = false;
+
+            return bytes_read;
+        }
+        return 0;
+    }
+
+    pub fn readInitialFseState(self: *DecodeState, bit_reader: anytype) !void {
         self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
         self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
         self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
@@ -130,7 +178,7 @@ const DecodeState = struct {
         self: *DecodeState,
         src: []const u8,
         comptime choice: DataType,
-        mode: Sequences.Header.Mode,
+        mode: SequencesSection.Header.Mode,
     ) !usize {
         const field_name = @tagName(choice);
         switch (mode) {
@@ -213,7 +261,13 @@ const DecodeState = struct {
         };
     }
 
-    fn executeSequenceSlice(self: *DecodeState, dest: []u8, write_pos: usize, literals: Literals, sequence: Sequence) !void {
+    fn executeSequenceSlice(
+        self: *DecodeState,
+        dest: []u8,
+        write_pos: usize,
+        literals: LiteralsSection,
+        sequence: Sequence,
+    ) !void {
         try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
 
         // TODO: should we validate offset against max_window_size?
@@ -225,7 +279,12 @@ const DecodeState = struct {
         std.mem.copy(u8, dest[write_pos + sequence.literal_length ..], dest[copy_start..copy_end]);
     }
 
-    fn executeSequenceRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, sequence: Sequence) !void {
+    fn executeSequenceRingBuffer(
+        self: *DecodeState,
+        dest: *RingBuffer,
+        literals: LiteralsSection,
+        sequence: Sequence,
+    ) !void {
         try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
         // TODO: check that ring buffer window is full enough for match copies
         const copy_slice = dest.sliceAt(dest.write_index + dest.data.len - sequence.offset, sequence.match_length);
@@ -234,11 +293,11 @@ const DecodeState = struct {
         for (copy_slice.second) |b| dest.writeAssumeCapacity(b);
     }
 
-    fn decodeSequenceSlice(
+    pub fn decodeSequenceSlice(
         self: *DecodeState,
         dest: []u8,
         write_pos: usize,
-        literals: Literals,
+        literals: LiteralsSection,
         bit_reader: anytype,
         last_sequence: bool,
     ) !usize {
@@ -255,10 +314,10 @@ const DecodeState = struct {
         return sequence.match_length + sequence.literal_length;
     }
 
-    fn decodeSequenceRingBuffer(
+    pub fn decodeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: Literals,
+        literals: LiteralsSection,
         bit_reader: anytype,
         last_sequence: bool,
     ) !usize {
@@ -280,7 +339,7 @@ const DecodeState = struct {
         return sequence.match_length + sequence.literal_length;
     }
 
-    fn nextLiteralMultiStream(self: *DecodeState, literals: Literals) !void {
+    fn nextLiteralMultiStream(self: *DecodeState, literals: LiteralsSection) !void {
         self.literal_stream_index += 1;
         try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
     }
@@ -290,7 +349,7 @@ const DecodeState = struct {
         try self.literal_stream_reader.init(bytes);
     }
 
-    fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
+    pub fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: LiteralsSection, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
             .raw => {
@@ -310,7 +369,7 @@ const DecodeState = struct {
                 // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
                 const huffman_tree = self.huffman_tree orelse unreachable;
                 const max_bit_count = huffman_tree.max_bit_count;
-                const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
+                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                     huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                     max_bit_count,
                 );
@@ -345,7 +404,7 @@ const DecodeState = struct {
                             },
                             .index => |index| {
                                 huffman_tree_index = index;
-                                const bit_count = Literals.HuffmanTree.weightToBitCount(
+                                const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                                     huffman_tree.nodes[index].weight,
                                     max_bit_count,
                                 );
@@ -359,7 +418,7 @@ const DecodeState = struct {
         }
     }
 
-    fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: Literals, len: usize) !void {
+    pub fn decodeLiteralsRingBuffer(self: *DecodeState, dest: *RingBuffer, literals: LiteralsSection, len: usize) !void {
         if (self.literal_written_count + len > literals.header.regenerated_size) return error.MalformedLiteralsLength;
         switch (literals.header.block_type) {
             .raw => {
@@ -378,7 +437,7 @@ const DecodeState = struct {
                 // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
                 const huffman_tree = self.huffman_tree orelse unreachable;
                 const max_bit_count = huffman_tree.max_bit_count;
-                const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
+                const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                     huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
                     max_bit_count,
                 );
@@ -413,7 +472,7 @@ const DecodeState = struct {
                             },
                             .index => |index| {
                                 huffman_tree_index = index;
-                                const bit_count = Literals.HuffmanTree.weightToBitCount(
+                                const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                                     huffman_tree.nodes[index].weight,
                                     max_bit_count,
                                 );
@@ -647,54 +706,6 @@ fn decodeRleBlockRingBuffer(dest: *RingBuffer, src: []const u8, block_size: u21,
     return block_size;
 }
 
-fn prepareDecodeState(
-    decode_state: *DecodeState,
-    src: []const u8,
-    literals: Literals,
-    sequences_header: Sequences.Header,
-) !usize {
-    if (literals.huffman_tree) |tree| {
-        decode_state.huffman_tree = tree;
-    } else if (literals.header.block_type == .treeless and decode_state.huffman_tree == null) {
-        return error.TreelessLiteralsFirst;
-    }
-
-    switch (literals.header.block_type) {
-        .raw, .rle => {},
-        .compressed, .treeless => {
-            decode_state.literal_stream_index = 0;
-            switch (literals.streams) {
-                .one => |slice| try decode_state.initLiteralStream(slice),
-                .four => |streams| try decode_state.initLiteralStream(streams[0]),
-            }
-        },
-    }
-
-    if (sequences_header.sequence_count > 0) {
-        var bytes_read = try decode_state.updateFseTable(
-            src,
-            .literal,
-            sequences_header.literal_lengths,
-        );
-
-        bytes_read += try decode_state.updateFseTable(
-            src[bytes_read..],
-            .offset,
-            sequences_header.offsets,
-        );
-
-        bytes_read += try decode_state.updateFseTable(
-            src[bytes_read..],
-            .match,
-            sequences_header.match_lengths,
-        );
-        decode_state.fse_tables_undefined = false;
-
-        return bytes_read;
-    }
-    return 0;
-}
-
 pub fn decodeBlock(
     dest: []u8,
     src: []const u8,
@@ -715,7 +726,7 @@ pub fn decodeBlock(
             const literals = try decodeLiteralsSection(src, &bytes_read);
             const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
 
-            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
+            bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
 
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
@@ -723,7 +734,7 @@ pub fn decodeBlock(
                 var bit_stream: ReverseBitReader = undefined;
                 try bit_stream.init(bit_stream_bytes);
 
-                try decode_state.readInitialState(&bit_stream);
+                try decode_state.readInitialFseState(&bit_stream);
 
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
@@ -780,7 +791,7 @@ pub fn decodeBlockRingBuffer(
             const literals = try decodeLiteralsSection(src, &bytes_read);
             const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
 
-            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
+            bytes_read += try decode_state.prepare(src[bytes_read..], literals, sequences_header);
 
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
@@ -788,7 +799,7 @@ pub fn decodeBlockRingBuffer(
                 var bit_stream: ReverseBitReader = undefined;
                 try bit_stream.init(bit_stream_bytes);
 
-                try decode_state.readInitialState(&bit_stream);
+                try decode_state.readInitialFseState(&bit_stream);
 
                 var i: usize = 0;
                 while (i < sequences_header.sequence_count) : (i += 1) {
@@ -928,7 +939,7 @@ pub fn decodeBlockHeader(src: *const [3]u8) frame.ZStandard.Block.Header {
     };
 }
 
-pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals {
+pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !LiteralsSection {
     // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     var bytes_read: usize = 0;
     const header = decodeLiteralsHeader(src, &bytes_read);
@@ -936,7 +947,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
         .raw => {
             const stream = src[bytes_read .. bytes_read + header.regenerated_size];
             consumed_count.* += header.regenerated_size + bytes_read;
-            return Literals{
+            return LiteralsSection{
                 .header = header,
                 .huffman_tree = null,
                 .streams = .{ .one = stream },
@@ -945,7 +956,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
         .rle => {
             const stream = src[bytes_read .. bytes_read + 1];
             consumed_count.* += 1 + bytes_read;
-            return Literals{
+            return LiteralsSection{
                 .header = header,
                 .huffman_tree = null,
                 .streams = .{ .one = stream },
@@ -966,7 +977,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
                 const stream = src[bytes_read .. bytes_read + total_streams_size];
                 bytes_read += total_streams_size;
                 consumed_count.* += bytes_read;
-                return Literals{
+                return LiteralsSection{
                     .header = header,
                     .huffman_tree = huffman_tree,
                     .streams = .{ .one = stream },
@@ -988,7 +999,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
 
             consumed_count.* += total_streams_size + bytes_read;
 
-            return Literals{
+            return LiteralsSection{
                 .header = header,
                 .huffman_tree = huffman_tree,
                 .streams = .{ .four = .{
@@ -1002,7 +1013,7 @@ pub fn decodeLiteralsSection(src: []const u8, consumed_count: *usize) !Literals
     }
 }
 
-fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanTree {
+fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !LiteralsSection.HuffmanTree {
     var bytes_read: usize = 0;
     bytes_read += 1;
     const header = src[0];
@@ -1094,7 +1105,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
     weights[symbol_count - 1] = @intCast(u4, std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1);
     log.debug("weights[{d}] = {d}", .{ symbol_count - 1, weights[symbol_count - 1] });
 
-    var weight_sorted_prefixed_symbols: [256]Literals.HuffmanTree.PrefixedSymbol = undefined;
+    var weight_sorted_prefixed_symbols: [256]LiteralsSection.HuffmanTree.PrefixedSymbol = undefined;
     for (weight_sorted_prefixed_symbols[0..symbol_count]) |_, i| {
         weight_sorted_prefixed_symbols[i] = .{
             .symbol = @intCast(u8, i),
@@ -1104,7 +1115,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
     }
 
     std.sort.sort(
-        Literals.HuffmanTree.PrefixedSymbol,
+        LiteralsSection.HuffmanTree.PrefixedSymbol,
         weight_sorted_prefixed_symbols[0..symbol_count],
         weights,
         lessThanByWeight,
@@ -1137,7 +1148,7 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
         }
     }
     consumed_count.* += bytes_read;
-    const tree = Literals.HuffmanTree{
+    const tree = LiteralsSection.HuffmanTree{
         .max_bit_count = max_number_of_bits,
         .symbol_count_minus_one = @intCast(u8, prefixed_symbol_count - 1),
         .nodes = weight_sorted_prefixed_symbols,
@@ -1148,8 +1159,8 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
 
 fn lessThanByWeight(
     weights: [256]u4,
-    lhs: Literals.HuffmanTree.PrefixedSymbol,
-    rhs: Literals.HuffmanTree.PrefixedSymbol,
+    lhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
+    rhs: LiteralsSection.HuffmanTree.PrefixedSymbol,
 ) bool {
     // NOTE: this function relies on the use of a stable sorting algorithm,
     //       otherwise a special case of if (weights[lhs] == weights[rhs]) return lhs < rhs;
@@ -1157,11 +1168,11 @@ fn lessThanByWeight(
     return weights[lhs.symbol] < weights[rhs.symbol];
 }
 
-pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.Header {
+pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) LiteralsSection.Header {
     // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     const start = consumed_count.*;
     const byte0 = src[0];
-    const block_type = @intToEnum(Literals.BlockType, byte0 & 0b11);
+    const block_type = @intToEnum(LiteralsSection.BlockType, byte0 & 0b11);
     const size_format = @intCast(u2, (byte0 & 0b1100) >> 2);
     var regenerated_size: u20 = undefined;
     var compressed_size: ?u18 = null;
@@ -1220,7 +1231,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He
             compressed_size,
         },
     );
-    return Literals.Header{
+    return LiteralsSection.Header{
         .block_type = block_type,
         .size_format = size_format,
         .regenerated_size = regenerated_size,
@@ -1228,7 +1239,7 @@ pub fn decodeLiteralsHeader(src: []const u8, consumed_count: *usize) Literals.He
     };
 }
 
-fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Header {
+pub fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !SequencesSection.Header {
     var sequence_count: u24 = undefined;
 
     var bytes_read: usize = 0;
@@ -1237,7 +1248,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
         bytes_read += 1;
         log.debug("decoded sequences header '{}': sequence count = 0", .{std.fmt.fmtSliceHexUpper(src[0..bytes_read])});
         consumed_count.* += bytes_read;
-        return Sequences.Header{
+        return SequencesSection.Header{
             .sequence_count = 0,
             .offsets = undefined,
             .match_lengths = undefined,
@@ -1258,9 +1269,9 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
     bytes_read += 1;
 
     consumed_count.* += bytes_read;
-    const matches_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00001100) >> 2);
-    const offsets_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b00110000) >> 4);
-    const literal_mode = @intToEnum(Sequences.Header.Mode, (compression_modes & 0b11000000) >> 6);
+    const matches_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00001100) >> 2);
+    const offsets_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b00110000) >> 4);
+    const literal_mode = @intToEnum(SequencesSection.Header.Mode, (compression_modes & 0b11000000) >> 6);
     log.debug("decoded sequences header '{}': (sc={d},o={s},m={s},l={s})", .{
         std.fmt.fmtSliceHexUpper(src[0..bytes_read]),
         sequence_count,
@@ -1270,7 +1281,7 @@ fn decodeSequencesHeader(src: []const u8, consumed_count: *usize) !Sequences.Hea
     });
     if (compression_modes & 0b11 != 0) return error.ReservedBitSet;
 
-    return Sequences.Header{
+    return SequencesSection.Header{
         .sequence_count = sequence_count,
         .offsets = offsets_mode,
         .match_lengths = matches_mode,
@@ -1428,25 +1439,25 @@ const ReversedByteReader = struct {
     }
 };
 
-const ReverseBitReader = struct {
+pub const ReverseBitReader = struct {
     byte_reader: ReversedByteReader,
     bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),
 
-    fn init(self: *ReverseBitReader, bytes: []const u8) !void {
+    pub fn init(self: *ReverseBitReader, bytes: []const u8) !void {
         self.byte_reader = ReversedByteReader.init(bytes);
         self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
         while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
     }
 
-    fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
+    pub fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
         return self.bit_reader.readBitsNoEof(U, num_bits);
     }
 
-    fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
+    pub fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
         return try self.bit_reader.readBits(U, num_bits, out_bits);
     }
 
-    fn alignToByte(self: *@This()) void {
+    pub fn alignToByte(self: *@This()) void {
         self.bit_reader.alignToByte();
     }
 };
@@ -1514,7 +1525,7 @@ fn dumpFseTable(prefix: []const u8, table: []const Table.Fse) void {
     }
 }
 
-fn dumpHuffmanTree(tree: Literals.HuffmanTree) void {
+fn dumpHuffmanTree(tree: LiteralsSection.HuffmanTree) void {
     log.debug("Huffman tree: max bit count = {}, symbol count = {}", .{ tree.max_bit_count, tree.symbol_count_minus_one + 1 });
     for (tree.nodes[0 .. tree.symbol_count_minus_one + 1]) |node| {
         log.debug("symbol = {[symbol]d}, prefix = {[prefix]d}, weight = {[weight]d}", node);
lib/std/compress/zstandard/types.zig
@@ -52,7 +52,7 @@ pub const frame = struct {
 };
 
 pub const compressed_block = struct {
-    pub const Literals = struct {
+    pub const LiteralsSection = struct {
         header: Header,
         huffman_tree: ?HuffmanTree,
         streams: Streams,
@@ -119,8 +119,8 @@ pub const compressed_block = struct {
         }
     };
 
-    pub const Sequences = struct {
-        header: Sequences.Header,
+    pub const SequencesSection = struct {
+        header: SequencesSection.Header,
         literals_length_table: Table,
         offset_table: Table,
         match_length_table: Table,
lib/std/compress/zstandard.zig
@@ -1,6 +1,7 @@
 const std = @import("std");
 
 pub const decompress = @import("zstandard/decompress.zig");
+pub usingnamespace @import("zstandard/types.zig");
 
 test "decompression" {
     const uncompressed = @embedFile("testdata/rfc8478.txt");