Commit 9c8cb777d4

Andrew Kelley <andrew@ziglang.org>
2025-07-28 22:13:59
std.compress.flate.Decompress: implement more bit reading
1 parent 6509fa1
Changed files (1)
lib
std
compress
lib/std/compress/flate/Decompress.zig
@@ -146,8 +146,8 @@ fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usiz
 // used. Shift the bit reader by that many bits, since those bits are now
 // consumed, and return the symbol.
 fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol {
-    const sym = try decoder.find(try self.peekBitsReverseBuffered(u15));
-    try self.shiftBits(sym.code_bits);
+    // Peek 15 bits and bit-reverse them before handing them to
+    // `decoder.find`.
+    const sym = try decoder.find(@bitReverse(try self.peekBits(u15)));
+    // Consume only the `sym.code_bits` bits reported for the matched symbol.
+    try self.tossBits(sym.code_bits);
     return sym;
 }
 
@@ -245,8 +245,8 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
                     var dec_lens: [286 + 30]u4 = @splat(0);
                     var pos: usize = 0;
                     while (pos < hlit + hdist) {
-                        const sym = try cl_dec.find(try d.peekBitsReverse(u7));
-                        try d.shiftBits(sym.code_bits);
+                        const sym = try cl_dec.find(@bitReverse(try d.peekBits(u7)));
+                        try d.tossBits(sym.code_bits);
                         pos += try d.dynamicCodeLength(sym.symbol, &dec_lens, pos);
                     }
                     if (pos > hlit + hdist) {
@@ -291,7 +291,7 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
                         // Handles fixed block non literal (length) code.
                         // Length code is followed by 5 bits of distance code.
                         const length = try d.decodeLength(@intCast(code - 257));
-                        const distance = try d.decodeDistance(try d.takeBitsReverseBuffered(u5));
+                        const distance = try d.decodeDistance(@bitReverse(try d.takeBits(u5)));
                         remaining = try writeMatch(w, length, distance, remaining);
                     },
                     else => return error.InvalidCode,
@@ -384,24 +384,47 @@ fn takeBits(d: *Decompress, comptime T: type) !T {
     };
 }
 
-fn takeBitsReverseBuffered(d: *Decompress, comptime T: type) !T {
-    _ = d;
-    @panic("TODO");
-}
-
-fn takeNBitsBuffered(d: *Decompress, n: u4) !u16 {
-    _ = d;
-    _ = n;
-    @panic("TODO");
+/// Returns the next `@bitSizeOf(T)` bits of the stream as a `T`, leaving the
+/// buffered bit state (`d.next_bits`, `d.remaining_bits`) untouched.
+///
+/// The first bit of the stream ends up in the least significant bit of the
+/// result. Integer types receive the value directly, enums via
+/// `@enumFromInt`, and every other type via `@bitCast`.
+///
+/// When fewer than `@bitSizeOf(T)` bits are buffered, the missing bits are
+/// peeked from `d.input`; an error from `peekInt` is returned to the caller.
+fn peekBits(d: *Decompress, comptime T: type) !T {
+    const U = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(T) } });
+    const remaining_bits = d.remaining_bits;
+    const next_bits = d.next_bits;
+    const u: U = if (remaining_bits >= @bitSizeOf(T)) @truncate(next_bits) else bits: {
+        const next_int = try d.input.peekInt(usize, .little);
+        const needed_bits = @bitSizeOf(T) - remaining_bits;
+        const low_mask = (@as(usize, 1) << needed_bits) - 1;
+        // The buffered bits come first in the stream, so they stay in the low
+        // positions (matching the `@truncate` fast path above and the `>>` in
+        // `tossBits`); the freshly peeked bits are placed above them.
+        // The `@intCast` assumes `next_bits` has no bits set at or above
+        // `remaining_bits` -- the visible writers of `next_bits` keep it so.
+        break :bits @intCast(next_bits | ((next_int & low_mask) << remaining_bits));
+    };
+    return switch (@typeInfo(T)) {
+        .int => u,
+        .@"enum" => @enumFromInt(u),
+        else => @bitCast(u),
+    };
+}
 
-fn peekBitsReverse(d: *Decompress, comptime T: type) !T {
-    _ = d;
-    @panic("TODO");
+/// Discards the next `n` bits of the stream.
+///
+/// If at least `n` bits are buffered they are shifted out of `d.next_bits`.
+/// Otherwise the buffered bits are dropped, one `usize` is taken from
+/// `d.input` (little-endian), and the bits still owed are shifted out of it;
+/// the rest of that integer becomes the new bit buffer. An error from
+/// `takeInt` is returned before any state is changed.
+// NOTE(review): the refill always takes a whole `usize`; presumably this fails
+// when fewer bytes remain in the input -- confirm how end of stream is handled.
+fn tossBits(d: *Decompress, n: u6) !void {
+    const buffered = d.remaining_bits;
+    if (buffered >= n) {
+        d.next_bits >>= n;
+        d.remaining_bits = buffered - n;
+        return;
+    }
+    const refill = try d.input.takeInt(usize, .little);
+    const shortfall = n - buffered;
+    d.next_bits = refill >> shortfall;
+    d.remaining_bits = @intCast(@bitSizeOf(usize) - @as(usize, shortfall));
+}
 
-fn peekBitsReverseBuffered(d: *Decompress, comptime T: type) !T {
+/// Not implemented yet: both arguments are ignored and any call panics with
+/// "TODO".
+fn takeNBitsBuffered(d: *Decompress, n: u4) !u16 {
     _ = d;
+    _ = n;
     @panic("TODO");
 }
 
@@ -422,15 +445,26 @@ fn alignBitsToByte(d: *Decompress) void {
     d.next_bits = 0;
 }
 
-fn shiftBits(d: *Decompress, n: u6) !void {
-    _ = d;
-    _ = n;
-    @panic("TODO");
-}
-
+/// Decodes one literal/length value (0...287) of the fixed Huffman code.
+///
+/// Reads the first 7 bits, then 1 or 2 more when needed, to obtain the full
+/// 7-, 8- or 9-bit code.
+/// ref: https://datatracker.ietf.org/doc/html/rfc1951#page-12
+///         Lit Value    Bits        Codes
+///          ---------    ----        -----
+///            0 - 143     8          00110000 through
+///                                   10111111
+///          144 - 255     9          110010000 through
+///                                   111111111
+///          256 - 279     7          0000000 through
+///                                   0010111
+///          280 - 287     8          11000000 through
+///                                   11000111
 fn readFixedCode(d: *Decompress) !u16 {
-    _ = d;
-    @panic("TODO");
+    // The 7 bits are bit-reversed so that `code7` can be compared directly
+    // with the 7-bit code prefixes from the table above.
+    const code7 = @bitReverse(try d.takeBits(u7));
+    return switch (code7) {
+        // 7-bit codes 0000000...0010111: values 256...279.
+        0...0b0010_111 => @as(u16, code7) + 256,
+        // 8-bit codes 00110000...10111111: one more bit completes the code,
+        // and subtracting the first code yields values 0...143.
+        0b0010_111 + 1...0b1011_111 => (@as(u16, code7) << 1) + @as(u16, try d.takeBits(u1)) - 0b0011_0000,
+        // 8-bit codes 11000000...11000111: values 280...287.
+        0b1011_111 + 1...0b1100_011 => (@as(u16, code7 - 0b1100000) << 1) + try d.takeBits(u1) + 280,
+        // 9-bit codes 110010000...111111111: two more bits, bit-reversed like
+        // the prefix, give values 144...255.
+        else => (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, @bitReverse(try d.takeBits(u2))) + 144,
+    };
 }
 
 pub const Symbol = packed struct {
@@ -731,20 +765,21 @@ test "encode/decode literals" {
     }
 }
 
-test "basic" {
-    // non compressed block (type 0)
+test "non compressed block (type 0)" {
     try testBasicCase(&[_]u8{
+        // Header byte 0b0000_0001: BFINAL=1, BTYPE=00 (stored). It is followed
+        // by LEN=0x000c (12) and NLEN=0xfff3 (~LEN), both little-endian.
         0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
+        // The 12 payload bytes are the expected output, stored verbatim.
         'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
     }, "Hello world\n");
+}
 
-    // fixed code block (type 1)
+test "fixed code block (type 1)" {
     try testBasicCase(&[_]u8{
+        // The low three bits of 0xf3 are 0b011: BFINAL=1, BTYPE=01 (fixed
+        // Huffman codes).
         0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
         0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
     }, "Hello world\n");
+}
 
-    // dynamic block (type 2)
+test "dynamic block (type 2)" {
     try testBasicCase(&[_]u8{
         0x3d, 0xc6, 0x39, 0x11, 0x00, 0x00, 0x0c, 0x02, // deflate data block type 2
         0x30, 0x2b, 0xb5, 0x52, 0x1e, 0xff, 0x96, 0x38,