Commit ddeabc9aa7

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-02 12:06:23
std.compress.zstandard: add `decodeFrameAlloc()`
1 parent 89f9c5c
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -1,5 +1,6 @@
 const std = @import("std");
 const assert = std.debug.assert;
+const Allocator = std.mem.Allocator;
 
 const types = @import("types.zig");
 const frame = types.frame;
@@ -32,6 +33,14 @@ pub fn isSkippableMagic(magic: u32) bool {
 ///   - `error.EndOfStream` if `source` contains fewer than 4 bytes
 pub fn decodeFrameType(source: anytype) error{ BadMagic, EndOfStream }!frame.Kind {
     const magic = try source.readIntLittle(u32);
+    return frameType(magic);
+}
+
+/// Returns the kind of frame associated to `magic`.
+///
+/// Errors returned:
+///   - `error.BadMagic` if `magic` is not a valid magic number.
+pub fn frameType(magic: u32) error{BadMagic}!frame.Kind {
     return if (magic == frame.ZStandard.magic_number)
         .zstandard
     else if (isSkippableMagic(magic))
@@ -78,6 +87,56 @@ pub fn decodeFrame(
     };
 }
 
+pub const DecodeResult = struct {
+    bytes: []u8,
+    read_count: usize,
+};
+pub const DecodedFrame = union(enum) {
+    zstandard: DecodeResult,
+    skippable: frame.Skippable.Header,
+};
+
+/// Decodes the frame at the start of `src` into `dest`. Returns the number of
+/// bytes read from `src` and the decoded bytes for a Zstandard frame, or the
+/// frame header for a Skippable frame.
+///
+/// Errors returned:
+///   - `error.BadMagic` if the first 4 bytes of `src` is not a valid magic
+///     number for a Zstandard or Skippable frame
+///   - `error.WindowSizeUnknown` if the frame does not have a valid window size
+///   - `error.WindowTooLarge` if the window size is larger than
+///     `window_size_max`
+///   - `error.DictionaryIdFlagUnsupported` if the frame uses a dictionary
+///   - `error.ChecksumFailure` if `verify_checksum` is true and the frame
+///     contains a checksum that does not match the checksum of the decompressed
+///     data
+///   - `error.ReservedBitSet` if the reserved bit of the frame header is set
+///   - `error.UnusedBitSet` if the unused bit of the frame header is set
+///   - `error.EndOfStream` if `src` does not contain a complete frame
+///   - `error.OutOfMemory` if `allocator` cannot allocate enough memory
+///   - an error in `block.Error` if there are errors decoding a block
+pub fn decodeFrameAlloc(
+    allocator: Allocator,
+    src: []const u8,
+    verify_checksum: bool,
+    window_size_max: usize,
+) !DecodedFrame {
+    var fbs = std.io.fixedBufferStream(src);
+    const reader = fbs.reader();
+    const magic = try reader.readIntLittle(u32);
+    return switch (try frameType(magic)) {
+        .zstandard => .{
+            .zstandard = try decodeZStandardFrameAlloc(allocator, src, verify_checksum, window_size_max),
+        },
+        .skippable => .{
+            .skippable = .{
+                .magic_number = magic,
+                .frame_size = try reader.readIntLittle(u32),
+            },
+        },
+    };
+}
+
 /// Returns the frame checksum corresponding to the data fed into `hasher`
 pub fn computeChecksum(hasher: *std.hash.XxHash64) u32 {
     const hash = hasher.final();
@@ -181,10 +240,11 @@ pub const FrameContext = struct {
     }
 };
 
-/// Decode a Zstandard from from `src` and return the decompressed bytes; see
-/// `decodeZStandardFrame()`. `allocator` is used to allocate both the returned
-/// slice and internal buffers used during decoding. The first four bytes of
-/// `src` must be the magic number for a Zstandard frame.
+/// Decode a Zstandard from from `src` and return the decompressed bytes and the
+/// number of bytes read; see `decodeZStandardFrame()`. `allocator` is used to
+/// allocate both the returned slice and internal buffers used during decoding.
+/// The first four bytes of `src` must be the magic number for a Zstandard
+/// frame.
 ///
 /// Errors returned:
 ///   - `error.WindowSizeUnknown` if the frame does not have a valid window size
@@ -200,11 +260,11 @@ pub const FrameContext = struct {
 ///   - `error.OutOfMemory` if `allocator` cannot allocate enough memory
 ///   - an error in `block.Error` if there are errors decoding a block
 pub fn decodeZStandardFrameAlloc(
-    allocator: std.mem.Allocator,
+    allocator: Allocator,
     src: []const u8,
     verify_checksum: bool,
     window_size_max: usize,
-) (error{OutOfMemory} || FrameContext.Error || FrameError)![]u8 {
+) (error{OutOfMemory} || FrameContext.Error || FrameError)!DecodeResult {
     var result = std.ArrayList(u8).init(allocator);
     assert(readInt(u32, src[0..4]) == frame.ZStandard.magic_number);
     var consumed_count: usize = 4;
@@ -258,7 +318,7 @@ pub fn decodeZStandardFrameAlloc(
             if (checksum != computeChecksum(hasher)) return error.ChecksumFailure;
         }
     }
-    return result.toOwnedSlice();
+    return DecodeResult{ .bytes = try result.toOwnedSlice(), .read_count = consumed_count };
 }
 
 /// Convenience wrapper for decoding all blocks in a frame; see `decodeBlock()`.