Commit 32cf1d7cbf

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-02-21 07:14:45
std.compress.zstandard: fix error sets for streaming API
1 parent c6ef83e
Changed files (2)
lib
std
compress
zstandard
lib/std/compress/zstandard/decode/huffman.zig
@@ -15,7 +15,12 @@ pub const Error = error{
     EndOfStream,
 };
 
-fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, weights: *[256]u4) !usize {
+fn decodeFseHuffmanTree(
+    source: anytype,
+    compressed_size: usize,
+    buffer: []u8,
+    weights: *[256]u4,
+) !usize {
     var stream = std.io.limitedReader(source, compressed_size);
     var bit_reader = readers.bitReader(stream.reader());
 
@@ -23,6 +28,7 @@ fn decodeFseHuffmanTree(source: anytype, compressed_size: usize, buffer: []u8, w
     const table_size = decodeFseTable(&bit_reader, 256, 6, &entries) catch |err| switch (err) {
         error.MalformedAccuracyLog, error.MalformedFseTable => |e| return e,
         error.EndOfStream => return error.MalformedFseTable,
+        else => |e| return e,
     };
     const accuracy_log = std.math.log2_int_ceil(usize, table_size);
 
@@ -46,7 +52,8 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
     };
     const accuracy_log = std.math.log2_int_ceil(usize, table_size);
 
-    const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse return error.MalformedHuffmanTree;
+    const start_index = std.math.cast(usize, counting_reader.bytes_read) orelse
+        return error.MalformedHuffmanTree;
     var huff_data = src[start_index..compressed_size];
     var huff_bits: readers.ReverseBitReader = undefined;
     huff_bits.init(huff_data) catch return error.MalformedHuffmanTree;
@@ -54,7 +61,12 @@ fn decodeFseHuffmanTreeSlice(src: []const u8, compressed_size: usize, weights: *
     return assignWeights(&huff_bits, accuracy_log, &entries, weights);
 }
 
-fn assignWeights(huff_bits: *readers.ReverseBitReader, accuracy_log: usize, entries: *[1 << 6]Table.Fse, weights: *[256]u4) !usize {
+fn assignWeights(
+    huff_bits: *readers.ReverseBitReader,
+    accuracy_log: usize,
+    entries: *[1 << 6]Table.Fse,
+    weights: *[256]u4,
+) !usize {
     var i: usize = 0;
     var even_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
     var odd_state: u32 = huff_bits.readBitsNoEof(u32, accuracy_log) catch return error.MalformedHuffmanTree;
@@ -173,7 +185,10 @@ fn buildHuffmanTree(weights: *[256]u4, symbol_count: usize) error{MalformedHuffm
     return tree;
 }
 
-pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.HuffmanTree {
+pub fn decodeHuffmanTree(
+    source: anytype,
+    buffer: []u8,
+) (@TypeOf(source).Error || Error)!LiteralsSection.HuffmanTree {
     const header = try source.readByte();
     var weights: [256]u4 = undefined;
     const symbol_count = if (header < 128)
@@ -185,7 +200,10 @@ pub fn decodeHuffmanTree(source: anytype, buffer: []u8) !LiteralsSection.Huffman
     return buildHuffmanTree(&weights, symbol_count);
 }
 
-pub fn decodeHuffmanTreeSlice(src: []const u8, consumed_count: *usize) Error!LiteralsSection.HuffmanTree {
+pub fn decodeHuffmanTreeSlice(
+    src: []const u8,
+    consumed_count: *usize,
+) Error!LiteralsSection.HuffmanTree {
     if (src.len == 0) return error.MalformedHuffmanTree;
     const header = src[0];
     var bytes_read: usize = 1;
lib/std/compress/zstandard/decompress.zig
@@ -64,7 +64,7 @@ pub const HeaderError = error{ BadMagic, EndOfStream, ReservedBitSet };
 ///   - `error.EndOfStream` if `source` contains fewer than 4 bytes
 ///   - `error.ReservedBitSet` if the frame is a Zstandard frame and any of the
 ///     reserved bits are set
-pub fn decodeFrameHeader(source: anytype) HeaderError!FrameHeader {
+pub fn decodeFrameHeader(source: anytype) (@TypeOf(source).Error || HeaderError)!FrameHeader {
     const magic = try source.readIntLittle(u32);
     const frame_type = try frameType(magic);
     switch (frame_type) {
@@ -596,7 +596,9 @@ pub fn frameWindowSize(header: ZstandardHeader) ?u64 {
 /// Errors returned:
 ///   - `error.ReservedBitSet` if any of the reserved bits of the header are set
 ///   - `error.EndOfStream` if `source` does not contain a complete header
-pub fn decodeZstandardHeader(source: anytype) error{ EndOfStream, ReservedBitSet }!ZstandardHeader {
+pub fn decodeZstandardHeader(
+    source: anytype,
+) (@TypeOf(source).Error || error{ EndOfStream, ReservedBitSet })!ZstandardHeader {
     const descriptor = @bitCast(ZstandardHeader.Descriptor, try source.readByte());
 
     if (descriptor.reserved) return error.ReservedBitSet;