Commit e73ca2444e

Andrew Kelley <andrew@ziglang.org>
2025-07-29 01:46:16
std.compress.flate.Decompress: implement peekBitsEnding and writeMatch
1 parent 7bf91d7
Changed files (2)
lib
std
lib/std/compress/flate/Decompress.zig
@@ -97,7 +97,7 @@ fn decodeLength(self: *Decompress, code: u8) !u16 {
     return if (ml.extra_bits == 0) // 0 - 5 extra bits
         ml.base
     else
-        ml.base + try self.takeNBitsBuffered(ml.extra_bits);
+        ml.base + try self.takeBitsRuntime(ml.extra_bits);
 }
 
 fn decodeDistance(self: *Decompress, code: u8) !u16 {
@@ -106,7 +106,7 @@ fn decodeDistance(self: *Decompress, code: u8) !u16 {
     return if (md.extra_bits == 0) // 0 - 13 extra bits
         md.base
     else
-        md.base + try self.takeNBitsBuffered(md.extra_bits);
+        md.base + try self.takeBitsRuntime(md.extra_bits);
 }
 
 // Decode code length symbol to code length. Writes decoded length into
@@ -293,7 +293,8 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
                         // Length code is followed by 5 bits of distance code.
                         const length = try d.decodeLength(@intCast(code - 257));
                         const distance = try d.decodeDistance(@bitReverse(try d.takeBits(u5)));
-                        remaining = try writeMatch(w, length, distance, remaining);
+                        try writeMatch(w, length, distance);
+                        remaining -= length;
                     },
                     else => return error.InvalidCode,
                 }
@@ -317,7 +318,8 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
                         const length = try d.decodeLength(sym.symbol);
                         const dsm = try d.decodeSymbol(&d.dst_dec);
                         const distance = try d.decodeDistance(dsm.symbol);
-                        remaining = try writeMatch(w, length, distance, remaining);
+                        try writeMatch(w, length, distance);
+                        remaining -= length;
                     },
                     .end_of_block => {
                         d.state = if (d.final_block) .protocol_footer else .block_header;
@@ -350,12 +352,19 @@ fn readInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.S
 
 /// Write match (back-reference to the same data slice) starting at `distance`
 /// back from current write position, and `length` of bytes.
-fn writeMatch(w: *Writer, length: u16, distance: u16, remaining: usize) !usize {
-    _ = w;
-    _ = length;
-    _ = distance;
-    _ = remaining;
-    @panic("TODO");
+fn writeMatch(w: *Writer, length: u16, distance: u16) !void {
+    if (w.end < length) return error.InvalidMatch;
+    if (length < Token.base_length) return error.InvalidMatch;
+    if (length > Token.max_length) return error.InvalidMatch;
+    if (distance < Token.min_distance) return error.InvalidMatch;
+    if (distance > Token.max_distance) return error.InvalidMatch;
+
+    // This is not a @memmove; it intentionally repeats patterns caused by
+    // iterating one byte at a time.
+    const dest = try w.writableSlicePreserve(flate.history_len, length);
+    const end = dest.ptr - w.buffer.ptr;
+    const src = w.buffer[end - distance ..][0..length];
+    for (dest, src) |*d, s| d.* = s;
 }
 
 fn takeBits(d: *Decompress, comptime T: type) !T {
@@ -417,35 +426,48 @@ fn takeBitsEnding(d: *Decompress, comptime T: type) !T {
     };
 }
 
-fn peekBits(d: *Decompress, comptime T: type) !T {
-    const U = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(T) } });
+fn peekBits(d: *Decompress, comptime U: type) !U {
     const remaining_bits = d.remaining_bits;
     const next_bits = d.next_bits;
-    if (remaining_bits >= @bitSizeOf(T)) {
-        const u: U = @truncate(next_bits);
-        return switch (@typeInfo(T)) {
-            .int => u,
-            .@"enum" => @enumFromInt(u),
-            else => @bitCast(u),
-        };
-    }
+    if (remaining_bits >= @bitSizeOf(U)) return @truncate(next_bits);
     const in = d.input;
     const next_int = in.peekInt(usize, .little) catch |err| switch (err) {
         error.ReadFailed => return error.ReadFailed,
-        error.EndOfStream => return peekBitsEnding(d, T),
-    };
-    const needed_bits = @bitSizeOf(T) - remaining_bits;
-    const u: U = @intCast(((next_int & ((@as(usize, 1) << needed_bits) - 1)) << remaining_bits) | next_bits);
-    return switch (@typeInfo(T)) {
-        .int => u,
-        .@"enum" => @enumFromInt(u),
-        else => @bitCast(u),
+        error.EndOfStream => return peekBitsEnding(d, U),
     };
+    const needed_bits = @bitSizeOf(U) - remaining_bits;
+    return @intCast(((next_int & ((@as(usize, 1) << needed_bits) - 1)) << remaining_bits) | next_bits);
 }
 
-fn peekBitsEnding(d: *Decompress, comptime T: type) !T {
-    _ = d;
-    @panic("TODO");
+fn peekBitsEnding(d: *Decompress, comptime U: type) !U {
+    const remaining_bits = d.remaining_bits;
+    const next_bits = d.next_bits;
+    const in = d.input;
+    var u: U = 0;
+    var remaining_needed_bits = @bitSizeOf(U) - remaining_bits;
+    var peek_len: usize = 0;
+    while (@bitSizeOf(U) >= 8 and remaining_needed_bits >= 8) {
+        peek_len += 1;
+        const byte = try specialPeek(in, next_bits, peek_len);
+        u = (u << 8) | byte;
+        remaining_needed_bits -= 8;
+    }
+    if (remaining_needed_bits != 0) {
+        peek_len += 1;
+        const byte = try specialPeek(in, next_bits, peek_len);
+        u = @intCast((@as(usize, u) << remaining_needed_bits) | (byte & ((@as(usize, 1) << remaining_needed_bits) - 1)));
+    }
+    return @intCast((@as(usize, u) << remaining_bits) | next_bits);
+}
+
+/// If there is any unconsumed data, handles EndOfStream by pretending there
+/// are zeroes afterwards.
+fn specialPeek(in: *Reader, next_bits: usize, n: usize) Reader.Error!u8 {
+    const peeked = in.peek(n) catch |err| switch (err) {
+        error.ReadFailed => return error.ReadFailed,
+        error.EndOfStream => if (next_bits == 0 and n == 0) return error.EndOfStream else return 0,
+    };
+    return peeked[n - 1];
 }
 
 fn tossBits(d: *Decompress, n: u6) !void {
@@ -472,10 +494,12 @@ fn tossBitsEnding(d: *Decompress, n: u6) !void {
     @panic("TODO");
 }
 
-fn takeNBitsBuffered(d: *Decompress, n: u4) !u16 {
-    _ = d;
-    _ = n;
-    @panic("TODO");
+fn takeBitsRuntime(d: *Decompress, n: u4) !u16 {
+    const x = try peekBits(d, u16);
+    const mask: u16 = (@as(u16, 1) << n) - 1;
+    const u: u16 = @as(u16, @truncate(x)) & mask;
+    try tossBits(d, n);
+    return u;
 }
 
 fn alignBitsToByte(d: *Decompress) void {
lib/std/compress/flate.zig
@@ -11,8 +11,8 @@ pub const history_len = 32768;
 /// of LZ77 and Huffman coding.
 pub const Compress = @import("flate/Compress.zig");
 
-/// Inflate is the decoding process that takes a Deflate bitstream for
-/// decompression and correctly produces the original full-size data or file.
+/// Inflate is the decoding process that consumes a Deflate bitstream and
+/// produces the original full-size data.
 pub const Decompress = @import("flate/Decompress.zig");
 
 /// Compression without Lempel-Ziv match searching. Faster compression, less