Commit 6509fa1cf3

Andrew Kelley <andrew@ziglang.org>
2025-07-28 20:42:43
std.compress.flate.Decompress: passing basic test case
1 parent 88ca750
Changed files (1)
lib
std
compress
lib/std/compress/flate/Decompress.zig
@@ -10,7 +10,11 @@ const Decompress = @This();
 const Token = @import("Token.zig");
 
 input: *Reader,
+next_bits: usize,
+remaining_bits: std.math.Log2Int(usize),
+
 reader: Reader,
+
 /// Hashes, produces checksum, of uncompressed data for gzip/zlib footer.
 hasher: Container.Hasher,
 
@@ -65,6 +69,8 @@ pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress {
             .end = 0,
         },
         .input = input,
+        .next_bits = 0,
+        .remaining_bits = 0,
         .hasher = .init(container),
         .lit_dec = .{},
         .dst_dec = .{},
@@ -228,15 +234,15 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
                         return error.InvalidDynamicBlockHeader;
 
                     // lengths for code lengths
-                    var cl_lens = [_]u4{0} ** 19;
-                    for (0..hclen) |i| {
-                        cl_lens[flate.HuffmanEncoder.codegen_order[i]] = try d.takeBits(u3);
+                    var cl_lens: [19]u4 = @splat(0);
+                    for (flate.HuffmanEncoder.codegen_order[0..hclen]) |i| {
+                        cl_lens[i] = try d.takeBits(u3);
                     }
                     var cl_dec: CodegenDecoder = .{};
                     try cl_dec.generate(&cl_lens);
 
                     // decoded code lengths
-                    var dec_lens = [_]u4{0} ** (286 + 30);
+                    var dec_lens: [286 + 30]u4 = @splat(0);
                     var pos: usize = 0;
                     while (pos < hlit + hdist) {
                         const sym = try cl_dec.find(try d.peekBitsReverse(u7));
@@ -352,8 +358,30 @@ fn writeMatch(w: *Writer, length: u16, distance: u16, remaining: usize) !usize {
 }
 
 fn takeBits(d: *Decompress, comptime T: type) !T {
-    _ = d;
-    @panic("TODO");
+    const U = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(T) } });
+    const remaining_bits = d.remaining_bits;
+    const next_bits = d.next_bits;
+    if (remaining_bits >= @bitSizeOf(T)) {
+        const u: U = @truncate(next_bits);
+        d.next_bits = next_bits >> @bitSizeOf(T);
+        d.remaining_bits = remaining_bits - @bitSizeOf(T);
+        return switch (@typeInfo(T)) {
+            .int => u,
+            .@"enum" => @enumFromInt(u),
+            else => @bitCast(u),
+        };
+    }
+    const in = d.input;
+    const next_int = try in.takeInt(usize, .little);
+    const needed_bits = @bitSizeOf(T) - remaining_bits;
+    const u: U = @intCast((next_bits << needed_bits) | (next_int & ((@as(usize, 1) << needed_bits) - 1)));
+    d.next_bits = next_int >> needed_bits;
+    d.remaining_bits = @intCast(@bitSizeOf(usize) - @as(usize, needed_bits));
+    return switch (@typeInfo(T)) {
+        .int => u,
+        .@"enum" => @enumFromInt(u),
+        else => @bitCast(u),
+    };
 }
 
 fn takeBitsReverseBuffered(d: *Decompress, comptime T: type) !T {
@@ -378,8 +406,20 @@ fn peekBitsReverseBuffered(d: *Decompress, comptime T: type) !T {
 }
 
 fn alignBitsToByte(d: *Decompress) void {
-    _ = d;
-    @panic("TODO");
+    const remaining_bits = d.remaining_bits;
+    const next_bits = d.next_bits;
+    if (remaining_bits == 0) return;
+    const discard_bits = remaining_bits % 8;
+    const n_bytes = remaining_bits / 8;
+    var put_back_bits = next_bits >> discard_bits;
+    const in = d.input;
+    in.seek -= n_bytes;
+    for (in.buffer[in.seek..][0..n_bytes]) |*b| {
+        b.* = @truncate(put_back_bits);
+        put_back_bits >>= 8;
+    }
+    d.remaining_bits = 0;
+    d.next_bits = 0;
 }
 
 fn shiftBits(d: *Decompress, n: u6) !void {
@@ -691,47 +731,37 @@ test "encode/decode literals" {
     }
 }
 
-test "decompress" {
-    const cases = [_]struct {
-        in: []const u8,
-        out: []const u8,
-    }{
-        // non compressed block (type 0)
-        .{
-            .in = &[_]u8{
-                0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
-                'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
-            },
-            .out = "Hello world\n",
-        },
-        // fixed code block (type 1)
-        .{
-            .in = &[_]u8{
-                0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
-                0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
-            },
-            .out = "Hello world\n",
-        },
-        // dynamic block (type 2)
-        .{
-            .in = &[_]u8{
-                0x3d, 0xc6, 0x39, 0x11, 0x00, 0x00, 0x0c, 0x02, // deflate data block type 2
-                0x30, 0x2b, 0xb5, 0x52, 0x1e, 0xff, 0x96, 0x38,
-                0x16, 0x96, 0x5c, 0x1e, 0x94, 0xcb, 0x6d, 0x01,
-            },
-            .out = "ABCDEABCD ABCDEABCD",
-        },
-    };
-    for (cases) |c| {
-        var fb: Reader = .fixed(c.in);
-        var aw: Writer.Allocating = .init(testing.allocator);
-        defer aw.deinit();
+test "basic" {
+    // non compressed block (type 0)
+    try testBasicCase(&[_]u8{
+        0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
+        'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
+    }, "Hello world\n");
+
+    // fixed code block (type 1)
+    try testBasicCase(&[_]u8{
+        0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
+        0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
+    }, "Hello world\n");
+
+    // dynamic block (type 2)
+    try testBasicCase(&[_]u8{
+        0x3d, 0xc6, 0x39, 0x11, 0x00, 0x00, 0x0c, 0x02, // deflate data block type 2
+        0x30, 0x2b, 0xb5, 0x52, 0x1e, 0xff, 0x96, 0x38,
+        0x16, 0x96, 0x5c, 0x1e, 0x94, 0xcb, 0x6d, 0x01,
+    }, "ABCDEABCD ABCDEABCD");
+}
 
-        var decompress: Decompress = .init(&fb, .raw, &.{});
-        const r = &decompress.reader;
-        _ = try r.streamRemaining(&aw.writer);
-        try testing.expectEqualStrings(c.out, aw.getWritten());
-    }
+fn testBasicCase(in: []const u8, out: []const u8) !void {
+    var reader: Reader = .fixed(in);
+    var aw: Writer.Allocating = .init(testing.allocator);
+    try aw.ensureUnusedCapacity(flate.history_len + 1);
+    defer aw.deinit();
+
+    var decompress: Decompress = .init(&reader, .raw, &.{});
+    const r = &decompress.reader;
+    _ = try r.streamRemaining(&aw.writer);
+    try testing.expectEqualStrings(out, aw.getWritten());
 }
 
 test "gzip decompress" {