Commit 18091723d5

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-22 03:32:16
std.compress.zstandard: cleanup decodeBlock
1 parent 61cb514
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -65,12 +65,14 @@ const DecodeState = struct {
     match_fse_buffer: []Table.Fse,
     literal_fse_buffer: []Table.Fse,
 
-    literal_written_count: usize,
+    fse_tables_undefined: bool,
 
     literal_stream_reader: ReverseBitReader(ReversedByteReader.Reader),
     literal_stream_bytes: ReversedByteReader,
     literal_stream_index: usize,
-    huffman_tree: Literals.HuffmanTree,
+    huffman_tree: ?Literals.HuffmanTree,
+
+    literal_written_count: usize,
 
     fn StateData(comptime max_accuracy_log: comptime_int) type {
         return struct {
@@ -129,7 +131,6 @@ const DecodeState = struct {
         src: []const u8,
         comptime choice: DataType,
         mode: Sequences.Header.Mode,
-        first_compressed_block: bool,
     ) !usize {
         const field_name = @tagName(choice);
         switch (mode) {
@@ -162,7 +163,7 @@ const DecodeState = struct {
                 dumpFseTable(field_name, @field(self, field_name).table.fse);
                 return counting_reader.bytes_read;
             },
-            .repeat => return if (first_compressed_block) error.RepeatModeFirst else 0,
+            .repeat => return if (self.fse_tables_undefined) error.RepeatModeFirst else 0,
         }
     }
 
@@ -275,7 +276,7 @@ const DecodeState = struct {
             },
             .compressed, .treeless => {
                 // const written_bytes_per_stream = (literals.header.regenerated_size + 3) / 4;
-                const huffman_tree = self.huffman_tree;
+                const huffman_tree = self.huffman_tree orelse unreachable;
                 const max_bit_count = huffman_tree.max_bit_count;
                 const starting_bit_count = Literals.HuffmanTree.weightToBitCount(
                     huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
@@ -399,14 +400,14 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
         .match_fse_buffer = &match_fse_data,
         .offset_fse_buffer = &offset_fse_data,
 
+        .fse_tables_undefined = true,
+
         .literal_written_count = 0,
         .literal_stream_reader = undefined,
         .literal_stream_bytes = undefined,
         .literal_stream_index = undefined,
-        .huffman_tree = undefined,
+        .huffman_tree = null,
     };
-    var first_compressed_block = true;
-    var first_compressed_literals = true;
     var written_count: usize = 0;
     while (true) : ({
         block_header = decodeBlockHeader(src[bytes_read..][0..3]);
@@ -417,8 +418,6 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
             src[bytes_read..],
             block_header,
             &decode_state,
-            &first_compressed_block,
-            &first_compressed_literals,
             &bytes_read,
             written_count,
         );
@@ -430,13 +429,77 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
     return written_count;
 }
 
+fn decodeRawBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing raw block - size {d}", .{block_size});
+    const data = src[0..block_size];
+    std.mem.copy(u8, dest, data);
+    consumed_count.* += block_size;
+    return block_size;
+}
+
+fn decodeRleBlock(dest: []u8, src: []const u8, block_size: u21, consumed_count: *usize) usize {
+    log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
+    var write_pos: usize = 0;
+    while (write_pos < block_size) : (write_pos += 1) {
+        dest[write_pos] = src[0];
+    }
+    consumed_count.* += 1;
+    return block_size;
+}
+
+fn prepareDecodeState(
+    decode_state: *DecodeState,
+    src: []const u8,
+    literals: Literals,
+    sequences_header: Sequences.Header,
+) !usize {
+    if (literals.huffman_tree) |tree| {
+        decode_state.huffman_tree = tree;
+    } else if (literals.header.block_type == .treeless and decode_state.huffman_tree == null) {
+        return error.TreelessLiteralsFirst;
+    }
+
+    switch (literals.header.block_type) {
+        .raw, .rle => {},
+        .compressed, .treeless => {
+            decode_state.literal_stream_index = 0;
+            switch (literals.streams) {
+                .one => |slice| try decode_state.initLiteralStream(slice),
+                .four => |streams| try decode_state.initLiteralStream(streams[0]),
+            }
+        },
+    }
+
+    if (sequences_header.sequence_count > 0) {
+        var bytes_read = try decode_state.updateFseTable(
+            src,
+            .literal,
+            sequences_header.literal_lengths,
+        );
+
+        bytes_read += try decode_state.updateFseTable(
+            src[bytes_read..],
+            .offset,
+            sequences_header.offsets,
+        );
+
+        bytes_read += try decode_state.updateFseTable(
+            src[bytes_read..],
+            .match,
+            sequences_header.match_lengths,
+        );
+        decode_state.fse_tables_undefined = false;
+
+        return bytes_read;
+    }
+    return 0;
+}
+
 pub fn decodeBlock(
     dest: []u8,
     src: []const u8,
     block_header: frame.ZStandard.Block.Header,
     decode_state: *DecodeState,
-    first_compressed_block: *bool,
-    first_compressed_literals: *bool,
     consumed_count: *usize,
     written_count: usize,
 ) !usize {
@@ -445,69 +508,14 @@ pub fn decodeBlock(
     if (block_maximum_size < block_size) return error.BlockSizeOverMaximum;
     // TODO: we probably want to enable safety for release-fast and release-small (or insert custom checks)
     switch (block_header.block_type) {
-        .raw => {
-            log.debug("writing raw block - size {d}", .{block_size});
-            const data = src[0..block_size];
-            std.mem.copy(u8, dest[written_count..], data);
-            consumed_count.* += block_size;
-            return block_size;
-        },
-        .rle => {
-            log.debug("writing rle block - '{x}'x{d}", .{ src[0], block_size });
-            var write_pos: usize = written_count;
-            while (write_pos < block_size + written_count) : (write_pos += 1) {
-                dest[write_pos] = src[0];
-            }
-            consumed_count.* += 1;
-            return block_size;
-        },
+        .raw => return decodeRawBlock(dest[written_count..], src, block_size, consumed_count),
+        .rle => return decodeRleBlock(dest[written_count..], src, block_size, consumed_count),
         .compressed => {
             var bytes_read: usize = 0;
             const literals = try decodeLiteralsSection(src, &bytes_read);
             const sequences_header = try decodeSequencesHeader(src[bytes_read..], &bytes_read);
 
-            if (first_compressed_literals.* and literals.header.block_type == .treeless)
-                return error.TreelessLiteralsFirst;
-
-            if (literals.huffman_tree) |tree| {
-                decode_state.huffman_tree = tree;
-                first_compressed_literals.* = false;
-            }
-
-            switch (literals.header.block_type) {
-                .raw, .rle => {},
-                .compressed, .treeless => {
-                    decode_state.literal_stream_index = 0;
-                    switch (literals.streams) {
-                        .one => |slice| try decode_state.initLiteralStream(slice),
-                        .four => |streams| try decode_state.initLiteralStream(streams[0]),
-                    }
-                },
-            }
-
-            if (sequences_header.sequence_count > 0) {
-                bytes_read += try decode_state.updateFseTable(
-                    src[bytes_read..],
-                    .literal,
-                    sequences_header.literal_lengths,
-                    first_compressed_block.*,
-                );
-
-                bytes_read += try decode_state.updateFseTable(
-                    src[bytes_read..],
-                    .offset,
-                    sequences_header.offsets,
-                    first_compressed_block.*,
-                );
-
-                bytes_read += try decode_state.updateFseTable(
-                    src[bytes_read..],
-                    .match,
-                    sequences_header.match_lengths,
-                    first_compressed_block.*,
-                );
-                first_compressed_block.* = false;
-            }
+            bytes_read += try prepareDecodeState(decode_state, src[bytes_read..], literals, sequences_header);
 
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {