Commit 31cc4605ab

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-09 10:15:00
std.compress.zstandard: fix errors and crashes in ZstandardStream
1 parent 55e6e94
Changed files (2)
lib
std
compress
zstandard
decode
lib/std/compress/zstandard/decode/block.zig
@@ -795,6 +795,7 @@ pub fn decodeBlockReader(
     if (block_size_max < block_size) return error.BlockSizeOverMaximum;
     switch (block_header.block_type) {
         .raw => {
+            if (block_size == 0) return;
             const slice = dest.sliceAt(dest.write_index, block_size);
             try source.readNoEof(slice.first);
             try source.readNoEof(slice.second);
lib/std/compress/zstandard.zig
@@ -12,12 +12,11 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
         const Self = @This();
 
         allocator: Allocator,
-        in_reader: ReaderType,
-        state: enum { NewFrame, InFrame },
+        source: std.io.CountingReader(ReaderType),
+        state: enum { NewFrame, InFrame, LastBlock },
         decode_state: decompress.block.DecodeState,
         frame_context: decompress.FrameContext,
         buffer: RingBuffer,
-        last_block: bool,
         literal_fse_buffer: []types.compressed_block.Table.Fse,
         match_fse_buffer: []types.compressed_block.Table.Fse,
         offset_fse_buffer: []types.compressed_block.Table.Fse,
@@ -32,12 +31,11 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
         pub fn init(allocator: Allocator, source: ReaderType) !Self {
             return Self{
                 .allocator = allocator,
-                .in_reader = source,
+                .source = std.io.countingReader(source),
                 .state = .NewFrame,
                 .decode_state = undefined,
                 .frame_context = undefined,
                 .buffer = undefined,
-                .last_block = undefined,
                 .literal_fse_buffer = undefined,
                 .match_fse_buffer = undefined,
                 .offset_fse_buffer = undefined,
@@ -48,22 +46,16 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
         }
 
         fn frameInit(self: *Self) !void {
-            var bytes: [4]u8 = undefined;
-            const bytes_read = try self.in_reader.readAll(&bytes);
-            if (bytes_read == 0) return error.NoBytes;
-            if (bytes_read < 4) return error.EndOfStream;
-            const frame_type = try decompress.frameType(std.mem.readIntLittle(u32, &bytes));
-            switch (frame_type) {
-                .skippable => {
-                    const size = try self.in_reader.readIntLittle(u32);
-                    try self.in_reader.skipBytes(size, .{});
+            const source_reader = self.source.reader();
+            switch (try decompress.decodeFrameHeader(source_reader)) {
+                .skippable => |header| {
+                    try source_reader.skipBytes(header.frame_size, .{});
                     self.state = .NewFrame;
                 },
-                .zstandard => {
+                .zstandard => |header| {
                     const frame_context = context: {
-                        const frame_header = try decompress.decodeZstandardHeader(self.in_reader);
                         break :context try decompress.FrameContext.init(
-                            frame_header,
+                            header,
                             window_size_max,
                             verify_checksum,
                         );
@@ -112,7 +104,6 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
                     self.frame_context = frame_context;
 
                     self.checksum = if (verify_checksum) null else {};
-                    self.last_block = false;
 
                     self.state = .InFrame;
                 },
@@ -134,10 +125,14 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
         }
 
         pub fn read(self: *Self, buffer: []u8) Error!usize {
+            const initial_count = self.source.bytes_read;
             if (buffer.len == 0) return 0;
             while (self.state == .NewFrame) {
                 self.frameInit() catch |err| switch (err) {
-                    error.NoBytes => return 0,
+                    error.EndOfStream => return if (self.source.bytes_read == initial_count)
+                        0
+                    else
+                        error.MalformedFrame,
                     error.OutOfMemory => return error.OutOfMemory,
                     else => return error.MalformedFrame,
                 };
@@ -147,15 +142,16 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
         }
 
         fn readInner(self: *Self, buffer: []u8) Error!usize {
-            std.debug.assert(self.state == .InFrame);
+            std.debug.assert(self.state != .NewFrame);
 
-            if (self.buffer.isEmpty() and !self.last_block) {
-                const header_bytes = self.in_reader.readBytesNoEof(3) catch return error.MalformedFrame;
+            const source_reader = self.source.reader();
+            while (self.buffer.isEmpty() and self.state != .LastBlock) {
+                const header_bytes = source_reader.readBytesNoEof(3) catch return error.MalformedFrame;
                 const block_header = decompress.block.decodeBlockHeader(&header_bytes);
 
                 decompress.block.decodeBlockReader(
                     &self.buffer,
-                    self.in_reader,
+                    source_reader,
                     block_header,
                     &self.decode_state,
                     self.frame_context.block_size_max,
@@ -164,15 +160,18 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
                 ) catch
                     return error.MalformedBlock;
 
-                self.last_block = block_header.last_block;
                 if (self.frame_context.hasher_opt) |*hasher| {
-                    const written_slice = self.buffer.sliceLast(self.buffer.len());
-                    hasher.update(written_slice.first);
-                    hasher.update(written_slice.second);
+                    const size = self.buffer.len();
+                    if (size > 0) {
+                        const written_slice = self.buffer.sliceLast(size);
+                        hasher.update(written_slice.first);
+                        hasher.update(written_slice.second);
+                    }
                 }
                 if (block_header.last_block) {
+                    self.state = .LastBlock;
                     if (self.frame_context.has_checksum) {
-                        const checksum = self.in_reader.readIntLittle(u32) catch return error.MalformedFrame;
+                        const checksum = source_reader.readIntLittle(u32) catch return error.MalformedFrame;
                         if (comptime verify_checksum) {
                             if (self.frame_context.hasher_opt) |*hasher| {
                                 if (checksum != decompress.computeChecksum(hasher)) return error.ChecksumFailure;
@@ -187,7 +186,7 @@ pub fn ZstandardStream(comptime ReaderType: type, comptime verify_checksum: bool
             while (written_count < decoded_data_len and written_count < buffer.len) : (written_count += 1) {
                 buffer[written_count] = self.buffer.read().?;
             }
-            if (self.buffer.len() == 0) {
+            if (self.state == .LastBlock and self.buffer.len() == 0) {
                 self.state = .NewFrame;
                 self.allocator.free(self.literal_fse_buffer);
                 self.allocator.free(self.match_fse_buffer);
@@ -219,7 +218,7 @@ fn testReader(data: []const u8, comptime expected: []const u8) !void {
     try std.testing.expectEqualSlices(u8, expected, buf);
 }
 
-test "decompression" {
+test "zstandard decompression" {
     const uncompressed = @embedFile("testdata/rfc8478.txt");
     const compressed3 = @embedFile("testdata/rfc8478.txt.zst.3");
     const compressed19 = @embedFile("testdata/rfc8478.txt.zst.19");