Commit 947ad3e268

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-31 03:24:27
std.compress.zstandard: add FrameContext and add literals into DecodeState
1 parent 2d35c16
Changed files (2)
lib
std
compress
lib/std/compress/zstandard/decompress.zig
@@ -81,6 +81,8 @@ pub const DecodeState = struct {
 
     literal_stream_reader: ReverseBitReader,
     literal_stream_index: usize,
+    literal_streams: LiteralsSection.Streams,
+    literal_header: LiteralsSection.Header,
     huffman_tree: ?LiteralsSection.HuffmanTree,
 
     literal_written_count: usize,
@@ -105,6 +107,10 @@ pub const DecodeState = struct {
         literals: LiteralsSection,
         sequences_header: SequencesSection.Header,
     ) (error{ BitStreamHasNoStartBit, TreelessLiteralsFirst } || FseTableError)!usize {
+        self.literal_written_count = 0;
+        self.literal_header = literals.header;
+        self.literal_streams = literals.streams;
+
         if (literals.huffman_tree) |tree| {
             self.huffman_tree = tree;
         } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
@@ -293,12 +299,11 @@ pub const DecodeState = struct {
         self: *DecodeState,
         dest: []u8,
         write_pos: usize,
-        literals: LiteralsSection,
         sequence: Sequence,
     ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > write_pos + sequence.literal_length) return error.MalformedSequence;
 
-        try self.decodeLiteralsSlice(dest[write_pos..], literals, sequence.literal_length);
+        try self.decodeLiteralsSlice(dest[write_pos..], sequence.literal_length);
         const copy_start = write_pos + sequence.literal_length - sequence.offset;
         const copy_end = copy_start + sequence.match_length;
         // NOTE: we ignore the usage message for std.mem.copy and copy with dest.ptr >= src.ptr
@@ -309,12 +314,11 @@ pub const DecodeState = struct {
     fn executeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         sequence: Sequence,
     ) (error{MalformedSequence} || DecodeLiteralsError)!void {
         if (sequence.offset > dest.data.len) return error.MalformedSequence;
 
-        try self.decodeLiteralsRingBuffer(dest, literals, sequence.literal_length);
+        try self.decodeLiteralsRingBuffer(dest, sequence.literal_length);
         const copy_start = dest.write_index + dest.data.len - sequence.offset;
         const copy_slice = dest.sliceAt(copy_start, sequence.match_length);
         // TODO: would std.mem.copy and figuring out dest slice be better/faster?
@@ -328,6 +332,7 @@ pub const DecodeState = struct {
         MalformedSequence,
         MalformedFseBits,
     } || DecodeLiteralsError;
+
     /// Decode one sequence from `bit_reader` into `dest`, written starting at
     /// `write_pos` and update FSE states if `last_sequence` is `false`. Returns
     /// `error.MalformedSequence` error if the decompressed sequence would be longer
@@ -340,7 +345,6 @@ pub const DecodeState = struct {
         self: *DecodeState,
         dest: []u8,
         write_pos: usize,
-        literals: LiteralsSection,
         bit_reader: *ReverseBitReader,
         sequence_size_limit: usize,
         last_sequence: bool,
@@ -349,7 +353,7 @@ pub const DecodeState = struct {
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
 
-        try self.executeSequenceSlice(dest, write_pos, literals, sequence);
+        try self.executeSequenceSlice(dest, write_pos, sequence);
         if (!last_sequence) {
             try self.updateState(.literal, bit_reader);
             try self.updateState(.match, bit_reader);
@@ -362,7 +366,6 @@ pub const DecodeState = struct {
     pub fn decodeSequenceRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         bit_reader: anytype,
         sequence_size_limit: usize,
         last_sequence: bool,
@@ -371,7 +374,7 @@ pub const DecodeState = struct {
         const sequence_length = @as(usize, sequence.literal_length) + sequence.match_length;
         if (sequence_length > sequence_size_limit) return error.MalformedSequence;
 
-        try self.executeSequenceRingBuffer(dest, literals, sequence);
+        try self.executeSequenceRingBuffer(dest, sequence);
         if (!last_sequence) {
             try self.updateState(.literal, bit_reader);
             try self.updateState(.match, bit_reader);
@@ -382,13 +385,12 @@ pub const DecodeState = struct {
 
     fn nextLiteralMultiStream(
         self: *DecodeState,
-        literals: LiteralsSection,
     ) error{BitStreamHasNoStartBit}!void {
         self.literal_stream_index += 1;
-        try self.initLiteralStream(literals.streams.four[self.literal_stream_index]);
+        try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
     }
 
-    fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
+    pub fn initLiteralStream(self: *DecodeState, bytes: []const u8) error{BitStreamHasNoStartBit}!void {
         try self.literal_stream_reader.init(bytes);
     }
 
@@ -400,11 +402,10 @@ pub const DecodeState = struct {
         self: *DecodeState,
         comptime T: type,
         bit_count_to_read: usize,
-        literals: LiteralsSection,
     ) LiteralBitsError!T {
         return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
-            if (literals.streams == .four and self.literal_stream_index < 3) {
-                try self.nextLiteralMultiStream(literals);
+            if (self.literal_streams == .four and self.literal_stream_index < 3) {
+                try self.nextLiteralMultiStream();
                 break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
                     return error.UnexpectedEndOfLiteralStream;
             } else {
@@ -427,23 +428,22 @@ pub const DecodeState = struct {
     pub fn decodeLiteralsSlice(
         self: *DecodeState,
         dest: []u8,
-        literals: LiteralsSection,
         len: usize,
     ) DecodeLiteralsError!void {
-        if (self.literal_written_count + len > literals.header.regenerated_size)
+        if (self.literal_written_count + len > self.literal_header.regenerated_size)
             return error.MalformedLiteralsLength;
 
-        switch (literals.header.block_type) {
+        switch (self.literal_header.block_type) {
             .raw => {
                 const literals_end = self.literal_written_count + len;
-                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
+                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 std.mem.copy(u8, dest, literal_data);
                 self.literal_written_count += len;
             },
             .rle => {
                 var i: usize = 0;
                 while (i < len) : (i += 1) {
-                    dest[i] = literals.streams.one[0];
+                    dest[i] = self.literal_streams.one[0];
                 }
                 self.literal_written_count += len;
             },
@@ -462,7 +462,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
+                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
@@ -496,23 +496,22 @@ pub const DecodeState = struct {
     pub fn decodeLiteralsRingBuffer(
         self: *DecodeState,
         dest: *RingBuffer,
-        literals: LiteralsSection,
         len: usize,
     ) DecodeLiteralsError!void {
-        if (self.literal_written_count + len > literals.header.regenerated_size)
+        if (self.literal_written_count + len > self.literal_header.regenerated_size)
             return error.MalformedLiteralsLength;
 
-        switch (literals.header.block_type) {
+        switch (self.literal_header.block_type) {
             .raw => {
                 const literals_end = self.literal_written_count + len;
-                const literal_data = literals.streams.one[self.literal_written_count..literals_end];
+                const literal_data = self.literal_streams.one[self.literal_written_count..literals_end];
                 dest.writeSliceAssumeCapacity(literal_data);
                 self.literal_written_count += len;
             },
             .rle => {
                 var i: usize = 0;
                 while (i < len) : (i += 1) {
-                    dest.writeAssumeCapacity(literals.streams.one[0]);
+                    dest.writeAssumeCapacity(self.literal_streams.one[0]);
                 }
                 self.literal_written_count += len;
             },
@@ -531,7 +530,7 @@ pub const DecodeState = struct {
                 while (i < len) : (i += 1) {
                     var prefix: u16 = 0;
                     while (true) {
-                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read, literals);
+                        const new_bits = try self.readLiteralsBits(u16, bit_count_to_read);
                         prefix <<= bit_count_to_read;
                         prefix |= new_bits;
                         bits_read += bit_count_to_read;
@@ -569,10 +568,6 @@ pub const DecodeState = struct {
     }
 };
 
-const literal_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.literal;
-const match_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
-const offset_table_size_max = 1 << types.compressed_block.table_accuracy_log_max.match;
-
 pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
     const hash = hasher.final();
     return @intCast(u32, hash & 0xFFFFFFFF);
@@ -625,6 +620,31 @@ pub fn decodeZStandardFrame(
     return ReadWriteCount{ .read_count = consumed_count, .write_count = written_count };
 }
 
+pub const FrameContext = struct {
+    hasher_opt: ?std.hash.XxHash64,
+    window_size: usize,
+    has_checksum: bool,
+    block_size_max: usize,
+
+    pub fn init(frame_header: frame.ZStandard.Header, window_size_max: usize, verify_checksum: bool) !FrameContext {
+        if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
+
+        const window_size_raw = frameWindowSize(frame_header) orelse return error.WindowSizeUnknown;
+        const window_size = if (window_size_raw > window_size_max)
+            return error.WindowTooLarge
+        else
+            @intCast(usize, window_size_raw);
+
+        const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
+        return .{
+            .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null,
+            .window_size = window_size,
+            .has_checksum = frame_header.descriptor.content_checksum_flag,
+            .block_size_max = @min(1 << 17, window_size),
+        };
+    }
+};
+
 /// Decode a Zstandard from from `src` and return the decompressed bytes; see
 /// `decodeZStandardFrame()`. Returns `error.WindowSizeUnknown` if the frame
 /// does not declare its content size or a window descriptor (this indicates a
@@ -639,33 +659,18 @@ pub fn decodeZStandardFrameAlloc(
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
 
-    const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
-
-    if (frame_header.descriptor.dictionary_id_flag != 0) return error.DictionaryIdFlagUnsupported;
-
-    const window_size_raw = frameWindowSize(frame_header) orelse return error.WindowSizeUnknown;
-    const window_size = if (window_size_raw > window_size_max)
-        return error.WindowTooLarge
-    else
-        @intCast(usize, window_size_raw);
-
-    const should_compute_checksum = frame_header.descriptor.content_checksum_flag and verify_checksum;
-    var hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null;
-
-    const block_size_maximum = @min(1 << 17, window_size);
-
-    var window_data = try allocator.alloc(u8, window_size);
-    defer allocator.free(window_data);
-    var ring_buffer = RingBuffer{
-        .data = window_data,
-        .write_index = 0,
-        .read_index = 0,
+    var frame_context = context: {
+        const frame_header = try decodeZStandardHeader(src[consumed_count..], &consumed_count);
+        break :context try FrameContext.init(frame_header, window_size_max, verify_checksum);
     };
 
+    var ring_buffer = try RingBuffer.init(allocator, frame_context.window_size);
+    defer ring_buffer.deinit(allocator);
+
     // These tables take 7680 bytes
-    var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
-    var match_fse_data: [match_table_size_max]Table.Fse = undefined;
-    var offset_fse_data: [offset_table_size_max]Table.Fse = undefined;
+    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
+    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
+    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;
 
     var block_header = decodeBlockHeader(src[consumed_count..][0..3]);
     consumed_count += 3;
@@ -687,6 +692,8 @@ pub fn decodeZStandardFrameAlloc(
         .fse_tables_undefined = true,
 
         .literal_written_count = 0,
+        .literal_header = undefined,
+        .literal_streams = undefined,
         .literal_stream_reader = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
@@ -695,30 +702,29 @@ pub fn decodeZStandardFrameAlloc(
         block_header = decodeBlockHeader(src[consumed_count..][0..3]);
         consumed_count += 3;
     }) {
-        if (block_header.block_size > block_size_maximum) return error.BlockSizeOverMaximum;
+        if (block_header.block_size > frame_context.block_size_max) return error.BlockSizeOverMaximum;
         const written_size = try decodeBlockRingBuffer(
             &ring_buffer,
             src[consumed_count..],
             block_header,
             &decode_state,
             &consumed_count,
-            block_size_maximum,
+            frame_context.block_size_max,
         );
-        if (written_size > block_size_maximum) return error.BlockSizeOverMaximum;
         const written_slice = ring_buffer.sliceLast(written_size);
         try result.appendSlice(written_slice.first);
         try result.appendSlice(written_slice.second);
-        if (hasher_opt) |*hasher| {
+        if (frame_context.hasher_opt) |*hasher| {
             hasher.update(written_slice.first);
             hasher.update(written_slice.second);
         }
         if (block_header.last_block) break;
     }
 
-    if (frame_header.descriptor.content_checksum_flag) {
+    if (frame_context.has_checksum) {
         const checksum = readIntSlice(u32, src[consumed_count .. consumed_count + 4]);
         consumed_count += 4;
-        if (hasher_opt) |*hasher| {
+        if (frame_context.hasher_opt) |*hasher| {
             if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
         }
     }
@@ -741,9 +747,9 @@ pub fn decodeFrameBlocks(
     hash: ?*std.hash.XxHash64,
 ) DecodeBlockError!usize {
     // These tables take 7680 bytes
-    var literal_fse_data: [literal_table_size_max]Table.Fse = undefined;
-    var match_fse_data: [match_table_size_max]Table.Fse = undefined;
-    var offset_fse_data: [offset_table_size_max]Table.Fse = undefined;
+    var literal_fse_data: [types.compressed_block.table_size_max.literal]Table.Fse = undefined;
+    var match_fse_data: [types.compressed_block.table_size_max.match]Table.Fse = undefined;
+    var offset_fse_data: [types.compressed_block.table_size_max.offset]Table.Fse = undefined;
 
     var block_header = decodeBlockHeader(src[0..3]);
     var bytes_read: usize = 3;
@@ -766,6 +772,8 @@ pub fn decodeFrameBlocks(
         .fse_tables_undefined = true,
 
         .literal_written_count = 0,
+        .literal_header = undefined,
+        .literal_streams = undefined,
         .literal_stream_reader = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
@@ -867,7 +875,8 @@ pub fn decodeBlock(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const literals = decodeLiteralsSection(src, &bytes_read) catch
+                return error.MalformedCompressedBlock;
             const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
                 return error.MalformedCompressedBlock;
 
@@ -889,7 +898,6 @@ pub fn decodeBlock(
                     const decompressed_size = decode_state.decodeSequenceSlice(
                         dest,
                         write_pos,
-                        literals,
                         &bit_stream,
                         sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
@@ -903,12 +911,11 @@ pub fn decodeBlock(
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], literals, len) catch
+                decode_state.decodeLiteralsSlice(dest[written_count + bytes_written ..], len) catch
                     return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
-            decode_state.literal_written_count = 0;
             assert(bytes_read == block_header.block_size);
             consumed_count.* += bytes_read;
             return bytes_written;
@@ -936,7 +943,8 @@ pub fn decodeBlockRingBuffer(
         .compressed => {
             if (src.len < block_size) return error.MalformedBlockSize;
             var bytes_read: usize = 0;
-            const literals = decodeLiteralsSection(src, &bytes_read) catch return error.MalformedCompressedBlock;
+            const literals = decodeLiteralsSection(src, &bytes_read) catch
+                return error.MalformedCompressedBlock;
             const sequences_header = decodeSequencesHeader(src[bytes_read..], &bytes_read) catch
                 return error.MalformedCompressedBlock;
 
@@ -956,7 +964,6 @@ pub fn decodeBlockRingBuffer(
                 while (i < sequences_header.sequence_count) : (i += 1) {
                     const decompressed_size = decode_state.decodeSequenceRingBuffer(
                         dest,
-                        literals,
                         &bit_stream,
                         sequence_size_limit,
                         i == sequences_header.sequence_count - 1,
@@ -970,14 +977,14 @@ pub fn decodeBlockRingBuffer(
 
             if (decode_state.literal_written_count < literals.header.regenerated_size) {
                 const len = literals.header.regenerated_size - decode_state.literal_written_count;
-                decode_state.decodeLiteralsRingBuffer(dest, literals, len) catch
+                decode_state.decodeLiteralsRingBuffer(dest, len) catch
                     return error.MalformedCompressedBlock;
                 bytes_written += len;
             }
 
-            decode_state.literal_written_count = 0;
             assert(bytes_read == block_header.block_size);
             consumed_count.* += bytes_read;
+            if (bytes_written > block_size_max) return error.BlockSizeOverMaximum;
             return bytes_written;
         },
         .reserved => return error.ReservedBlock,
lib/std/compress/zstandard/types.zig
@@ -386,6 +386,11 @@ pub const compressed_block = struct {
         pub const match = 6;
         pub const offset = 5;
     };
+    pub const table_size_max = struct {
+        pub const literal = 1 << table_accuracy_log_max.literal;
+        pub const match = 1 << table_accuracy_log_max.match;
+        pub const offset = 1 << table_accuracy_log_max.match;
+    };
 };
 
 test {