Commit ee4f5b3f92

Andrew Kelley <andrew@ziglang.org>
2025-07-25 08:31:00
std.compress.zstd: respect the window length
1 parent 7f1c044
Changed files (2)
lib
std
lib/std/compress/zstd/Decompress.zig
@@ -10,6 +10,7 @@ input: *Reader,
 reader: Reader,
 state: State,
 verify_checksum: bool,
+window_len: u32,
 err: ?Error = null,
 
 const State = union(enum) {
@@ -29,6 +30,8 @@ pub const Options = struct {
     /// Verifying checksums is not implemented yet and will cause a panic if
     /// you set this to true.
     verify_checksum: bool = false,
+    /// Affects the minimum capacity of the provided buffer.
+    window_len: u32 = zstd.default_window_len,
 };
 
 pub const Error = error{
@@ -65,11 +68,14 @@ pub const Error = error{
     WindowSizeUnknown,
 };
 
+/// If buffer that is written to is not big enough, some streams will fail with
+/// `error.OutputBufferUndersize`. A safe value is `zstd.default_window_len * 2`.
 pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress {
     return .{
         .input = input,
         .state = .new_frame,
         .verify_checksum = options.verify_checksum,
+        .window_len = options.window_len,
         .reader = .{
             .vtable = &.{ .stream = stream },
             .buffer = buffer,
@@ -143,6 +149,7 @@ fn initFrame(d: *Decompress, window_size_max: usize, magic: Frame.Magic) !void {
 
 fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) !usize {
     const in = d.input;
+    const window_len = d.window_len;
 
     const header_bytes = try in.takeArray(3);
     const block_header: Frame.Zstandard.Block.Header = @bitCast(header_bytes.*);
@@ -153,12 +160,12 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
     var bytes_written: usize = 0;
     switch (block_header.type) {
         .raw => {
-            try in.streamExact(w, block_size);
+            try in.streamExactPreserve(w, window_len, block_size);
             bytes_written = block_size;
         },
         .rle => {
             const byte = try in.takeByte();
-            try w.splatByteAll(byte, block_size);
+            try w.splatBytePreserve(window_len, byte, block_size);
             bytes_written = block_size;
         },
         .compressed => {
@@ -167,7 +174,7 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
             var offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse = undefined;
             var literals_buffer: [zstd.block_size_max]u8 = undefined;
             var sequence_buffer: [zstd.block_size_max]u8 = undefined;
-            var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer);
+            var decode: Frame.Zstandard.Decode = .init(&literal_fse_buffer, &match_fse_buffer, &offset_fse_buffer, window_len);
             var remaining: Limit = .limited(block_size);
             const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer);
             const sequences_header = try SequencesSection.Header.decode(in, &remaining);
@@ -185,15 +192,16 @@ fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame)
                     try decode.readInitialFseState(&bit_stream);
 
                     // Ensures the following calls to `decodeSequence` will not flush.
-                    if (frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize;
-                    const dest = (try w.writableSliceGreedy(frame_block_size_max))[0..frame_block_size_max];
+                    if (window_len + frame_block_size_max > w.buffer.len) return error.OutputBufferUndersize;
+                    const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max];
+                    const write_pos = dest.ptr - w.buffer.ptr;
                     for (0..sequences_header.sequence_count - 1) |_| {
-                        bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
+                        bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
                         try decode.updateState(.literal, &bit_stream);
                         try decode.updateState(.match, &bit_stream);
                         try decode.updateState(.offset, &bit_stream);
                     }
-                    bytes_written += try decode.decodeSequence(dest, bytes_written, &bit_stream);
+                    bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
                     if (bytes_written > dest.len) return error.MalformedSequence;
                     w.advance(bytes_written);
                 }
@@ -363,6 +371,7 @@ pub const Frame = struct {
         };
 
         pub const Decode = struct {
+            window_len: u32,
             repeat_offsets: [3]u32,
 
             offset: StateData(8),
@@ -397,8 +406,10 @@ pub const Frame = struct {
                 literal_fse_buffer: []Table.Fse,
                 match_fse_buffer: []Table.Fse,
                 offset_fse_buffer: []Table.Fse,
+                window_len: u32,
             ) Decode {
                 return .{
+                    .window_len = window_len,
                     .repeat_offsets = .{
                         zstd.start_repeated_offset_1,
                         zstd.start_repeated_offset_2,
@@ -698,19 +709,19 @@ pub const Frame = struct {
                 };
             }
 
-            /// Decode `len` bytes of literals into `dest`.
-            fn decodeLiterals(self: *Decode, dest: *Writer, len: usize) !void {
-                switch (self.literal_header.block_type) {
+            /// Decode `len` bytes of literals into `w`.
+            fn decodeLiterals(d: *Decode, w: *Writer, len: usize) !void {
+                switch (d.literal_header.block_type) {
                     .raw => {
-                        try dest.writeAll(self.literal_streams.one[self.literal_written_count..][0..len]);
+                        try w.writeAll(d.literal_streams.one[d.literal_written_count..][0..len]);
                     },
                     .rle => {
-                        try dest.splatByteAll(self.literal_streams.one[0], len);
+                        try w.splatByteAll(d.literal_streams.one[0], len);
                     },
                     .compressed, .treeless => {
-                        if (len > dest.buffer.len) return error.OutputBufferUndersize;
-                        const buf = try dest.writableSlice(len);
-                        const huffman_tree = self.huffman_tree.?;
+                        if (len > w.buffer.len) return error.OutputBufferUndersize;
+                        const buf = try w.writableSlice(len);
+                        const huffman_tree = d.huffman_tree.?;
                         const max_bit_count = huffman_tree.max_bit_count;
                         const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
                             huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
@@ -722,7 +733,7 @@ pub const Frame = struct {
                         for (buf) |*out| {
                             var prefix: u16 = 0;
                             while (true) {
-                                const new_bits = try self.readLiteralsBits(bit_count_to_read);
+                                const new_bits = try d.readLiteralsBits(bit_count_to_read);
                                 prefix <<= bit_count_to_read;
                                 prefix |= new_bits;
                                 bits_read += bit_count_to_read;
lib/std/compress/zstd.zig
@@ -1,12 +1,11 @@
 const std = @import("../std.zig");
 const assert = std.debug.assert;
 
+pub const Decompress = @import("zstd/Decompress.zig");
+
 /// Recommended amount by the standard. Lower than this may result in inability
 /// to decompress common streams.
 pub const default_window_len = 8 * 1024 * 1024;
-
-pub const Decompress = @import("zstd/Decompress.zig");
-
 pub const block_size_max = 1 << 17;
 
 pub const literals_length_default_distribution = [36]i16{