Commit 8a963fd66e

Igor Anić <igor.anic@gmail.com>
2024-03-02 21:10:36
flate: 32 bit BitReader
Extend BitReader to accept size of internal buffer. It can be u64 (only option until now) or u32.
1 parent 90c1a2c
Changed files (2)
lib
std
lib/std/compress/flate/bit_reader.zig
@@ -2,8 +2,16 @@ const std = @import("std");
 const assert = std.debug.assert;
 const testing = std.testing;
 
-pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
-    return BitReader(@TypeOf(reader)).init(reader);
+pub fn bitReader(comptime T: type, reader: anytype) BitReader(T, @TypeOf(reader)) {
+    return BitReader(T, @TypeOf(reader)).init(reader);
+}
+
+pub fn BitReader64(comptime ReaderType: type) type {
+    return BitReader(u64, ReaderType);
+}
+
+pub fn BitReader32(comptime ReaderType: type) type {
+    return BitReader(u32, ReaderType);
 }
 
 /// Bit reader used during inflate (decompression). Has internal buffer of 64
@@ -15,12 +23,16 @@ pub fn bitReader(reader: anytype) BitReader(@TypeOf(reader)) {
 /// fill buffer from forward_reader by calling fill in advance and readF with
 /// buffered flag set.
 ///
-pub fn BitReader(comptime ReaderType: type) type {
+pub fn BitReader(T: type, comptime ReaderType: type) type {
+    assert(T == u32 or T == u64);
+    const t_bytes: usize = @sizeOf(T);
+    const Tshift = if (T == u64) u6 else u5;
+
     return struct {
         // Underlying reader used for filling internal bits buffer
         forward_reader: ReaderType = undefined,
         // Internal buffer of 64 bits
-        bits: u64 = 0,
+        bits: T = 0,
         // Number of bits in the buffer
         nbits: u32 = 0,
 
@@ -51,14 +63,14 @@ pub fn BitReader(comptime ReaderType: type) type {
 
             // Number of empty bytes in bits, round nbits to whole bytes.
             const empty_bytes =
-                @as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise
+                @as(u8, if (self.nbits & 0x7 == 0) t_bytes else t_bytes - 1) - // 8 for 8, 16, 24..., 7 otherwise
                 (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8
 
-            var buf: [8]u8 = [_]u8{0} ** 8;
+            var buf: [t_bytes]u8 = [_]u8{0} ** t_bytes;
             const bytes_read = self.forward_reader.readAll(buf[0..empty_bytes]) catch 0;
             if (bytes_read > 0) {
-                const u: u64 = std.mem.readInt(u64, buf[0..8], .little);
-                self.bits |= u << @as(u6, @intCast(self.nbits));
+                const u: T = std.mem.readInt(T, buf[0..t_bytes], .little);
+                self.bits |= u << @as(Tshift, @intCast(self.nbits));
                 self.nbits += 8 * @as(u8, @intCast(bytes_read));
                 return;
             }
@@ -99,7 +111,17 @@ pub fn BitReader(comptime ReaderType: type) type {
 
         /// Read with flags provided.
         pub fn readF(self: *Self, comptime U: type, comptime how: u3) !U {
-            const n: u6 = @bitSizeOf(U);
+            if (U == T) {
+                assert(how == 0);
+                assert(self.alignBits() == 0);
+                try self.fill(@bitSizeOf(T));
+                assert(self.nbits == @bitSizeOf(T));
+                const v = self.bits;
+                self.nbits = 0;
+                self.bits = 0;
+                return v;
+            }
+            const n: Tshift = @bitSizeOf(U);
             switch (how) {
                 0 => { // `normal` read
                     try self.fill(n); // ensure that there are n bits in the buffer
@@ -157,7 +179,7 @@ pub fn BitReader(comptime ReaderType: type) type {
         }
 
         /// Advance buffer for n bits.
-        pub fn shift(self: *Self, n: u6) !void {
+        pub fn shift(self: *Self, n: Tshift) !void {
             if (n > self.nbits) return error.EndOfStream;
             self.bits >>= n;
             self.nbits -= n;
@@ -218,10 +240,10 @@ pub fn BitReader(comptime ReaderType: type) type {
     };
 }
 
-test "BitReader" {
+test "readF" {
     var fbs = std.io.fixedBufferStream(&[_]u8{ 0xf3, 0x48, 0xcd, 0xc9, 0x00, 0x00 });
-    var br = bitReader(fbs.reader());
-    const F = BitReader(@TypeOf(fbs.reader())).flag;
+    var br = bitReader(u64, fbs.reader());
+    const F = BitReader64(@TypeOf(fbs.reader())).flag;
 
     try testing.expectEqual(@as(u8, 48), br.nbits);
     try testing.expectEqual(@as(u64, 0xc9cd48f3), br.bits);
@@ -254,36 +276,38 @@ test "BitReader" {
 }
 
 test "read block type 1 data" {
-    const data = [_]u8{
-        0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
-        0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
-        0x0c, 0x01, 0x02, 0x03, //
-        0xaa, 0xbb, 0xcc, 0xdd,
-    };
-    var fbs = std.io.fixedBufferStream(&data);
-    var br = bitReader(fbs.reader());
-    const F = BitReader(@TypeOf(fbs.reader())).flag;
+    inline for ([_]type{ u64, u32 }) |T| {
+        const data = [_]u8{
+            0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
+            0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
+            0x0c, 0x01, 0x02, 0x03, //
+            0xaa, 0xbb, 0xcc, 0xdd,
+        };
+        var fbs = std.io.fixedBufferStream(&data);
+        var br = bitReader(T, fbs.reader());
+        const F = BitReader(T, @TypeOf(fbs.reader())).flag;
 
-    try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
-    try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
+        try testing.expectEqual(@as(u1, 1), try br.readF(u1, 0)); // bfinal
+        try testing.expectEqual(@as(u2, 1), try br.readF(u2, 0)); // block_type
 
-    for ("Hello world\n") |c| {
-        try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
+        for ("Hello world\n") |c| {
+            try testing.expectEqual(@as(u8, c), try br.readF(u8, F.reverse) - 0x30);
+        }
+        try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
+        br.alignToByte();
+        try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
+        try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
+        try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
     }
-    try testing.expectEqual(@as(u7, 0), try br.readF(u7, 0)); // end of block
-    br.alignToByte();
-    try testing.expectEqual(@as(u32, 0x0302010c), try br.readF(u32, 0));
-    try testing.expectEqual(@as(u16, 0xbbaa), try br.readF(u16, 0));
-    try testing.expectEqual(@as(u16, 0xddcc), try br.readF(u16, 0));
 }
 
-test "init" {
+test "shift/fill" {
     const data = [_]u8{
         0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
         0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
     };
     var fbs = std.io.fixedBufferStream(&data);
-    var br = bitReader(fbs.reader());
+    var br = bitReader(u64, fbs.reader());
 
     try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
     try br.shift(8);
@@ -303,31 +327,39 @@ test "init" {
 }
 
 test "readAll" {
-    const data = [_]u8{
-        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
-        0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
-    };
-    var fbs = std.io.fixedBufferStream(&data);
-    var br = bitReader(fbs.reader());
+    inline for ([_]type{ u64, u32 }) |T| {
+        const data = [_]u8{
+            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
+            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
+        };
+        var fbs = std.io.fixedBufferStream(&data);
+        var br = bitReader(T, fbs.reader());
 
-    try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits);
+        switch (T) {
+            u64 => try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits),
+            u32 => try testing.expectEqual(@as(u32, 0x04_03_02_01), br.bits),
+            else => unreachable,
+        }
 
-    var out: [16]u8 = undefined;
-    try br.readAll(out[0..]);
-    try testing.expect(br.nbits == 0);
-    try testing.expect(br.bits == 0);
+        var out: [16]u8 = undefined;
+        try br.readAll(out[0..]);
+        try testing.expect(br.nbits == 0);
+        try testing.expect(br.bits == 0);
 
-    try testing.expectEqualSlices(u8, data[0..16], &out);
+        try testing.expectEqualSlices(u8, data[0..16], &out);
+    }
 }
 
 test "readFixedCode" {
-    const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
+    inline for ([_]type{ u64, u32 }) |T| {
+        const fixed_codes = @import("huffman_encoder.zig").fixed_codes;
 
-    var fbs = std.io.fixedBufferStream(&fixed_codes);
-    var rdr = bitReader(fbs.reader());
+        var fbs = std.io.fixedBufferStream(&fixed_codes);
+        var rdr = bitReader(T, fbs.reader());
 
-    for (0..286) |c| {
-        try testing.expectEqual(c, try rdr.readFixedCode());
+        for (0..286) |c| {
+            try testing.expectEqual(c, try rdr.readFixedCode());
+        }
+        try testing.expect(rdr.nbits == 0);
     }
-    try testing.expect(rdr.nbits == 0);
 }
lib/std/compress/flate/inflate.zig
@@ -3,7 +3,7 @@ const assert = std.debug.assert;
 const testing = std.testing;
 
 const hfd = @import("huffman_decoder.zig");
-const BitReader = @import("bit_reader.zig").BitReader;
+const BitReader = @import("bit_reader.zig").BitReader64;
 const CircularBuffer = @import("CircularBuffer.zig");
 const Container = @import("container.zig").Container;
 const Token = @import("Token.zig");