Commit 476d2fe1fa

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-12 03:05:34
std.compress.zstandard: fix zstandardStream finishing early
1 parent 373d8ef
Changed files (1)
lib
std
compress
lib/std/compress/zstandard.zig
@@ -27,6 +27,7 @@ pub fn ZstandardStream(
         literals_buffer: []u8,
         sequence_buffer: []u8,
         checksum: if (verify_checksum) ?u32 else void,
+        current_frame_decompressed_size: usize,
 
         pub const Error = ReaderType.Error || error{
             ChecksumFailure,
@@ -51,6 +52,7 @@ pub fn ZstandardStream(
                 .literals_buffer = undefined,
                 .sequence_buffer = undefined,
                 .checksum = undefined,
+                .current_frame_decompressed_size = undefined,
             };
         }
 
@@ -113,6 +115,7 @@ pub fn ZstandardStream(
                     self.frame_context = frame_context;
 
                     self.checksum = if (verify_checksum) null else {};
+                    self.current_frame_decompressed_size = 0;
 
                     self.state = .InFrame;
                 },
@@ -134,20 +137,24 @@ pub fn ZstandardStream(
         }
 
         pub fn read(self: *Self, buffer: []u8) Error!usize {
-            const initial_count = self.source.bytes_read;
             if (buffer.len == 0) return 0;
-            while (self.state == .NewFrame) {
-                self.frameInit() catch |err| switch (err) {
-                    error.EndOfStream => return if (self.source.bytes_read == initial_count)
-                        0
-                    else
-                        error.MalformedFrame,
-                    error.OutOfMemory => return error.OutOfMemory,
-                    else => return error.MalformedFrame,
-                };
-            }
 
-            return self.readInner(buffer);
+            var size: usize = 0;
+            while (size == 0) {
+                while (self.state == .NewFrame) {
+                    const initial_count = self.source.bytes_read;
+                    self.frameInit() catch |err| switch (err) {
+                        error.EndOfStream => return if (self.source.bytes_read == initial_count)
+                            0
+                        else
+                            error.MalformedFrame,
+                        error.OutOfMemory => return error.OutOfMemory,
+                        else => return error.MalformedFrame,
+                    };
+                }
+                size = try self.readInner(buffer);
+            }
+            return size;
         }
 
         fn readInner(self: *Self, buffer: []u8) Error!usize {
@@ -172,6 +179,7 @@ pub fn ZstandardStream(
 
                 if (self.frame_context.hasher_opt) |*hasher| {
                     const size = self.buffer.len();
+                    self.current_frame_decompressed_size += size;
                     if (size > 0) {
                         const written_slice = self.buffer.sliceLast(size);
                         hasher.update(written_slice.first);
@@ -190,12 +198,17 @@ pub fn ZstandardStream(
                             }
                         }
                     }
+                    if (self.frame_context.content_size) |content_size| {
+                        if (content_size != self.current_frame_decompressed_size) {
+                            return error.MalformedFrame;
+                        }
+                    }
                 }
             }
 
-            const decoded_data_len = self.buffer.len();
+            const size = @min(self.buffer.len(), buffer.len);
             var count: usize = 0;
-            while (count < decoded_data_len and count < buffer.len) : (count += 1) {
+            while (count < size) : (count += 1) {
                 buffer[count] = self.buffer.read().?;
             }
             if (self.state == .LastBlock and self.buffer.len() == 0) {