Commit cbfaa876d4

dweiller <4678790+dweiller@users.noreplay.github.com>
2023-01-23 02:47:46
std.compress.zstandard: cleanup ReverseBitReader
1 parent c819e58
Changed files (1)
lib
std
compress
zstandard
lib/std/compress/zstandard/decompress.zig
@@ -68,8 +68,7 @@ const DecodeState = struct {
 
     fse_tables_undefined: bool,
 
-    literal_stream_reader: ReverseBitReader(ReversedByteReader.Reader),
-    literal_stream_bytes: ReversedByteReader,
+    literal_stream_reader: ReverseBitReader,
     literal_stream_index: usize,
     huffman_tree: ?Literals.HuffmanTree,
 
@@ -288,9 +287,7 @@ const DecodeState = struct {
 
     fn initLiteralStream(self: *DecodeState, bytes: []const u8) !void {
         log.debug("initing literal stream: {}", .{std.fmt.fmtSliceHexUpper(bytes)});
-        self.literal_stream_bytes = reversedByteReader(bytes);
-        self.literal_stream_reader = reverseBitReader(self.literal_stream_bytes.reader());
-        while (0 == try self.literal_stream_reader.readBitsNoEof(u1, 1)) {}
+        try self.literal_stream_reader.init(bytes);
     }
 
     fn decodeLiteralsSlice(self: *DecodeState, dest: []u8, literals: Literals, len: usize) !void {
@@ -532,7 +529,6 @@ pub fn decodeZStandardFrameAlloc(allocator: std.mem.Allocator, src: []const u8,
 
         .literal_written_count = 0,
         .literal_stream_reader = undefined,
-        .literal_stream_bytes = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
     };
@@ -591,7 +587,6 @@ pub fn decodeFrameBlocks(dest: []u8, src: []const u8, consumed_count: *usize, ha
 
         .literal_written_count = 0,
         .literal_stream_reader = undefined,
-        .literal_stream_bytes = undefined,
         .literal_stream_index = undefined,
         .huffman_tree = null,
     };
@@ -725,10 +720,9 @@ pub fn decodeBlock(
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
-                var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
-                var bit_stream = reverseBitReader(reverse_byte_reader.reader());
+                var bit_stream: ReverseBitReader = undefined;
+                try bit_stream.init(bit_stream_bytes);
 
-                while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
                 try decode_state.readInitialState(&bit_stream);
 
                 var i: usize = 0;
@@ -791,10 +785,9 @@ pub fn decodeBlockRingBuffer(
             var bytes_written: usize = 0;
             if (sequences_header.sequence_count > 0) {
                 const bit_stream_bytes = src[bytes_read..block_size];
-                var reverse_byte_reader = reversedByteReader(bit_stream_bytes);
-                var bit_stream = reverseBitReader(reverse_byte_reader.reader());
+                var bit_stream: ReverseBitReader = undefined;
+                try bit_stream.init(bit_stream_bytes);
 
-                while (0 == try bit_stream.readBitsNoEof(u1, 1)) {}
                 try decode_state.readInitialState(&bit_stream);
 
                 var i: usize = 0;
@@ -1028,9 +1021,8 @@ fn decodeHuffmanTree(src: []const u8, consumed_count: *usize) !Literals.HuffmanT
         const accuracy_log = std.math.log2_int_ceil(usize, table_size);
 
         var huff_data = src[1 + counting_reader.bytes_read .. compressed_size + 1];
-        var huff_data_bytes = reversedByteReader(huff_data);
-        var huff_bits = reverseBitReader(huff_data_bytes.reader());
-        while (0 == try huff_bits.readBitsNoEof(u1, 1)) {}
+        var huff_bits: ReverseBitReader = undefined;
+        try huff_bits.init(huff_data);
 
         dumpFseTable("huffman", entries[0..table_size]);
 
@@ -1415,48 +1407,49 @@ const ReversedByteReader = struct {
 
     const Reader = std.io.Reader(*ReversedByteReader, error{}, readFn);
 
+    fn init(bytes: []const u8) ReversedByteReader {
+        return .{
+            .bytes = bytes,
+            .remaining_bytes = bytes.len,
+        };
+    }
+
     fn reader(self: *ReversedByteReader) Reader {
         return .{ .context = self };
     }
-};
-
-fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize {
-    if (ctx.remaining_bytes == 0) return 0;
-    const byte_index = ctx.remaining_bytes - 1;
-    buffer[0] = ctx.bytes[byte_index];
-    // buffer[0] = @bitReverse(ctx.bytes[byte_index]);
-    ctx.remaining_bytes = byte_index;
-    return 1;
-}
 
-fn reversedByteReader(bytes: []const u8) ReversedByteReader {
-    return ReversedByteReader{
-        .remaining_bytes = bytes.len,
-        .bytes = bytes,
-    };
-}
+    fn readFn(ctx: *ReversedByteReader, buffer: []u8) !usize {
+        if (ctx.remaining_bytes == 0) return 0;
+        const byte_index = ctx.remaining_bytes - 1;
+        buffer[0] = ctx.bytes[byte_index];
+        // buffer[0] = @bitReverse(ctx.bytes[byte_index]);
+        ctx.remaining_bytes = byte_index;
+        return 1;
+    }
+};
 
-fn ReverseBitReader(comptime Reader: type) type {
-    return struct {
-        underlying: std.io.BitReader(.Big, Reader),
+const ReverseBitReader = struct {
+    byte_reader: ReversedByteReader,
+    bit_reader: std.io.BitReader(.Big, ReversedByteReader.Reader),
 
-        fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
-            return self.underlying.readBitsNoEof(U, num_bits);
-        }
+    fn init(self: *ReverseBitReader, bytes: []const u8) !void {
+        self.byte_reader = ReversedByteReader.init(bytes);
+        self.bit_reader = std.io.bitReader(.Big, self.byte_reader.reader());
+        while (0 == self.readBitsNoEof(u1, 1) catch return error.BitStreamHasNoStartBit) {}
+    }
 
-        fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
-            return try self.underlying.readBits(U, num_bits, out_bits);
-        }
+    fn readBitsNoEof(self: *@This(), comptime U: type, num_bits: usize) !U {
+        return self.bit_reader.readBitsNoEof(U, num_bits);
+    }
 
-        fn alignToByte(self: *@This()) void {
-            self.underlying.alignToByte();
-        }
-    };
-}
+    fn readBits(self: *@This(), comptime U: type, num_bits: usize, out_bits: *usize) !U {
+        return try self.bit_reader.readBits(U, num_bits, out_bits);
+    }
 
-fn reverseBitReader(reader: anytype) ReverseBitReader(@TypeOf(reader)) {
-    return .{ .underlying = std.io.bitReader(.Big, reader) };
-}
+    fn alignToByte(self: *@This()) void {
+        self.bit_reader.alignToByte();
+    }
+};
 
 fn BitReader(comptime Reader: type) type {
     return struct {