master
   1const std = @import("../../std.zig");
   2const assert = std.debug.assert;
   3const flate = std.compress.flate;
   4const testing = std.testing;
   5const Writer = std.Io.Writer;
   6const Reader = std.Io.Reader;
   7const Container = flate.Container;
   8
   9const Decompress = @This();
  10const token = @import("token.zig");
  11
  12input: *Reader,
  13consumed_bits: u3,
  14
  15reader: Reader,
  16
  17container_metadata: Container.Metadata,
  18
  19lit_dec: LiteralDecoder,
  20dst_dec: DistanceDecoder,
  21
  22final_block: bool,
  23state: State,
  24
  25err: ?Error,
  26
  27const BlockType = enum(u2) {
  28    stored = 0,
  29    fixed = 1,
  30    dynamic = 2,
  31    invalid = 3,
  32};
  33
/// Resumable position in the decoding state machine driven by `streamInner`.
///
/// The `*_literal` and `*_match` variants park work that was already decoded
/// but could not be emitted because the output limit was reached.
const State = union(enum) {
    /// Expecting the container header (nothing to read for `.raw`).
    protocol_header,
    /// Expecting a 3-bit block header (BFINAL, then BTYPE).
    block_header,
    /// Inside a stored block; payload is the number of bytes still to copy.
    stored_block: u16,
    /// Decoding symbols of a block that uses the fixed Huffman codes.
    fixed_block,
    /// A literal from a fixed block that still has to be written.
    fixed_block_literal: u8,
    /// A match length from a fixed block whose distance has not been read yet.
    fixed_block_match: u16,
    /// Decoding symbols of a block that uses dynamic Huffman codes.
    dynamic_block,
    /// A literal from a dynamic block that still has to be written.
    dynamic_block_literal: u8,
    /// A match length from a dynamic block whose distance has not been read yet.
    dynamic_block_match: u16,
    /// Expecting the container footer (after the final block).
    protocol_footer,
    /// The whole stream has been decoded; further reads report end of stream.
    end,
};
  47
  48pub const Error = Container.Error || error{
  49    InvalidCode,
  50    InvalidMatch,
  51    WrongStoredBlockNlen,
  52    InvalidBlockType,
  53    InvalidDynamicBlockHeader,
  54    ReadFailed,
  55    OversubscribedHuffmanTree,
  56    IncompleteHuffmanTree,
  57    MissingEndOfBlockCode,
  58    EndOfStream,
  59};
  60
  61const direct_vtable: Reader.VTable = .{
  62    .stream = streamDirect,
  63    .rebase = rebaseFallible,
  64    .discard = discardDirect,
  65    .readVec = readVec,
  66};
  67
  68const indirect_vtable: Reader.VTable = .{
  69    .stream = streamIndirect,
  70    .rebase = rebaseFallible,
  71    .discard = discardIndirect,
  72    .readVec = readVec,
  73};
  74
  75/// `input` buffer is asserted to be at least 10 bytes, or EOF before then.
  76///
  77/// If `buffer` is provided then asserted to have `flate.max_window_len`
  78/// capacity.
  79pub fn init(input: *Reader, container: Container, buffer: []u8) Decompress {
  80    if (buffer.len != 0) assert(buffer.len >= flate.max_window_len);
  81    return .{
  82        .reader = .{
  83            .vtable = if (buffer.len == 0) &direct_vtable else &indirect_vtable,
  84            .buffer = buffer,
  85            .seek = 0,
  86            .end = 0,
  87        },
  88        .input = input,
  89        .consumed_bits = 0,
  90        .container_metadata = .init(container),
  91        .lit_dec = .{},
  92        .dst_dec = .{},
  93        .final_block = false,
  94        .state = .protocol_header,
  95        .err = null,
  96    };
  97}
  98
  99fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void {
 100    rebase(r, capacity);
 101}
 102
 103fn rebase(r: *Reader, capacity: usize) void {
 104    assert(capacity <= r.buffer.len - flate.history_len);
 105    assert(r.end + capacity > r.buffer.len);
 106    const discard_n = @min(r.seek, r.end - flate.history_len);
 107    const keep = r.buffer[discard_n..r.end];
 108    @memmove(r.buffer[0..keep.len], keep);
 109    r.end = keep.len;
 110    r.seek -= discard_n;
 111}
 112
/// `Reader.VTable.discard` for direct mode: discards up to `limit` decompressed
/// bytes by streaming them into a temporary writer whose drain/sendFile come
/// from `std.Io.Writer.Discarding` and whose buffer (this reader's buffer)
/// provides the match history.
///
/// This could be improved so that when an amount is discarded that includes an
/// entire frame, skip decoding that frame.
///
/// NOTE(review): `init` installs `direct_vtable` (and so this function) only
/// when `buffer` is empty. With an empty buffer the condition below is always
/// true, and `rebase` then evaluates `r.buffer.len - flate.history_len`, which
/// underflows. The deferred `assert(writer.end != 0)` also cannot hold for a
/// writer with no buffer. Confirm whether direct-mode discard is reachable.
fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
    // Drop consumed bytes when fewer than `flate.history_len` bytes are free.
    if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len);
    var writer: Writer = .{
        .vtable = &.{
            .drain = std.Io.Writer.Discarding.drain,
            .sendFile = std.Io.Writer.Discarding.sendFile,
        },
        .buffer = r.buffer,
        .end = r.end,
    };
    // Whatever the writer left in the shared buffer stays as history, and all
    // of it is marked as consumed.
    defer {
        assert(writer.end != 0);
        r.end = writer.end;
        r.seek = r.end;
    }
    const n = r.stream(&writer, limit) catch |err| switch (err) {
        // The discarding writer never fails.
        error.WriteFailed => unreachable,
        error.ReadFailed => return error.ReadFailed,
        error.EndOfStream => return error.EndOfStream,
    };
    assert(n <= @intFromEnum(limit));
    return n;
}
 138
 139fn discardIndirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
 140    const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
 141    if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len);
 142    var writer: Writer = .{
 143        .buffer = r.buffer,
 144        .end = r.end,
 145        .vtable = &.{ .drain = Writer.unreachableDrain },
 146    };
 147    {
 148        defer r.end = writer.end;
 149        _ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
 150            error.WriteFailed => unreachable,
 151            else => |e| return e,
 152        };
 153    }
 154    const n = limit.minInt(r.end - r.seek);
 155    r.seek += n;
 156    return n;
 157}
 158
 159fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
 160    _ = data;
 161    const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
 162    return streamIndirectInner(d);
 163}
 164
 165fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
 166    const r = &d.reader;
 167    if (r.buffer.len - r.end < flate.history_len) rebase(r, flate.history_len);
 168    var writer: Writer = .{
 169        .buffer = r.buffer,
 170        .end = r.end,
 171        .vtable = &.{
 172            .drain = Writer.unreachableDrain,
 173            .rebase = Writer.unreachableRebase,
 174        },
 175    };
 176    defer r.end = writer.end;
 177    _ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
 178        error.WriteFailed => unreachable,
 179        else => |e| return e,
 180    };
 181    return 0;
 182}
 183
 184fn decodeLength(self: *Decompress, code_int: u5) !u16 {
 185    if (code_int > 28) return error.InvalidCode;
 186    const l: token.LenCode = .fromInt(code_int);
 187    const base = l.base();
 188    const extra = l.extraBits();
 189    return token.min_length + (base | try self.takeBits(extra));
 190}
 191
 192fn decodeDistance(self: *Decompress, code_int: u5) !u16 {
 193    if (code_int > 29) return error.InvalidCode;
 194    const d: token.DistCode = .fromInt(code_int);
 195    const base = d.base();
 196    const extra = d.extraBits();
 197    return token.min_distance + (base | try self.takeBits(extra));
 198}
 199
 200/// Decode code length symbol to code length. Writes decoded length into
 201/// lens slice starting at position pos. Returns number of positions
 202/// advanced.
 203fn dynamicCodeLength(self: *Decompress, code: u16, lens: []u4, pos: usize) !usize {
 204    if (pos >= lens.len)
 205        return error.InvalidDynamicBlockHeader;
 206
 207    switch (code) {
 208        0...15 => {
 209            // Represent code lengths of 0 - 15
 210            lens[pos] = @intCast(code);
 211            return 1;
 212        },
 213        16 => {
 214            // Copy the previous code length 3 - 6 times.
 215            // The next 2 bits indicate repeat length
 216            const n: u8 = @as(u8, try self.takeIntBits(u2)) + 3;
 217            if (pos == 0 or pos + n > lens.len)
 218                return error.InvalidDynamicBlockHeader;
 219            for (0..n) |i| {
 220                lens[pos + i] = lens[pos + i - 1];
 221            }
 222            return n;
 223        },
 224        // Repeat a code length of 0 for 3 - 10 times. (3 bits of length)
 225        17 => return @as(u8, try self.takeIntBits(u3)) + 3,
 226        // Repeat a code length of 0 for 11 - 138 times (7 bits of length)
 227        18 => return @as(u8, try self.takeIntBits(u7)) + 11,
 228        else => return error.InvalidDynamicBlockHeader,
 229    }
 230}
 231
 232fn decodeSymbol(self: *Decompress, decoder: anytype) !Symbol {
 233    // Maximum code len is 15 bits.
 234    const sym = try decoder.find(@bitReverse(try self.peekIntBitsShort(u15)));
 235    try self.tossBitsShort(sym.code_bits);
 236    return sym;
 237}
 238
 239fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
 240    const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
 241    return streamFallible(d, w, limit);
 242}
 243
 244fn streamIndirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
 245    const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
 246    _ = limit;
 247    _ = w;
 248    return streamIndirectInner(d);
 249}
 250
 251fn streamFallible(d: *Decompress, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
 252    return streamInner(d, w, limit) catch |err| switch (err) {
 253        error.EndOfStream => {
 254            if (d.state == .end) {
 255                return error.EndOfStream;
 256            } else {
 257                d.err = error.EndOfStream;
 258                return error.ReadFailed;
 259            }
 260        },
 261        error.WriteFailed => return error.WriteFailed,
 262        else => |e| {
 263            // In the event of an error, state is unmodified so that it can be
 264            // better used to diagnose the failure.
 265            d.err = e;
 266            return error.ReadFailed;
 267        },
 268    };
 269}
 270
 271fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader.StreamError)!usize {
 272    var remaining = @intFromEnum(limit);
 273    const in = d.input;
 274    sw: switch (d.state) {
 275        .protocol_header => switch (d.container_metadata.container()) {
 276            .gzip => {
 277                const Header = extern struct {
 278                    magic: u16 align(1),
 279                    method: u8,
 280                    flags: packed struct(u8) {
 281                        text: bool,
 282                        hcrc: bool,
 283                        extra: bool,
 284                        name: bool,
 285                        comment: bool,
 286                        reserved: u3,
 287                    },
 288                    mtime: u32 align(1),
 289                    xfl: u8,
 290                    os: u8,
 291                };
 292                const header = try in.takeStruct(Header, .little);
 293                if (header.magic != 0x8b1f or header.method != 0x08)
 294                    return error.BadGzipHeader;
 295                if (header.flags.extra) {
 296                    const extra_len = try in.takeInt(u16, .little);
 297                    try in.discardAll(extra_len);
 298                }
 299                if (header.flags.name) {
 300                    _ = try in.discardDelimiterInclusive(0);
 301                }
 302                if (header.flags.comment) {
 303                    _ = try in.discardDelimiterInclusive(0);
 304                }
 305                if (header.flags.hcrc) {
 306                    try in.discardAll(2);
 307                }
 308                continue :sw .block_header;
 309            },
 310            .zlib => {
 311                const header = try in.takeArray(2);
 312                const cmf: packed struct(u8) { cm: u4, cinfo: u4 } = @bitCast(header[0]);
 313                if (cmf.cm != 8 or cmf.cinfo > 7) return error.BadZlibHeader;
 314                continue :sw .block_header;
 315            },
 316            .raw => continue :sw .block_header,
 317        },
 318        .block_header => {
 319            d.final_block = (try d.takeIntBits(u1)) != 0;
 320            const block_type: BlockType = @enumFromInt(try d.takeIntBits(u2));
 321            switch (block_type) {
 322                .stored => {
 323                    d.alignBitsForward();
 324                    // everything after this is byte aligned in stored block
 325                    const len = try in.takeInt(u16, .little);
 326                    const nlen = try in.takeInt(u16, .little);
 327                    if (len != ~nlen) return error.WrongStoredBlockNlen;
 328                    continue :sw .{ .stored_block = len };
 329                },
 330                .fixed => continue :sw .fixed_block,
 331                .dynamic => {
 332                    const hlit: u16 = @as(u16, try d.takeIntBits(u5)) + 257; // number of ll code entries present - 257
 333                    const hdist: u16 = @as(u16, try d.takeIntBits(u5)) + 1; // number of distance code entries - 1
 334                    const hclen: u8 = @as(u8, try d.takeIntBits(u4)) + 4; // hclen + 4 code lengths are encoded
 335
 336                    if (hlit > 286 or hdist > 30)
 337                        return error.InvalidDynamicBlockHeader;
 338
 339                    // lengths for code lengths
 340                    var cl_lens: [19]u4 = @splat(0);
 341                    for (token.codegen_order[0..hclen]) |i| {
 342                        cl_lens[i] = try d.takeIntBits(u3);
 343                    }
 344                    var cl_dec: CodegenDecoder = .{};
 345                    try cl_dec.generate(&cl_lens);
 346
 347                    // decoded code lengths
 348                    var dec_lens: [286 + 30]u4 = @splat(0);
 349                    var pos: usize = 0;
 350                    while (pos < hlit + hdist) {
 351                        const peeked = @bitReverse(try d.peekIntBitsShort(u7));
 352                        const sym = try cl_dec.find(peeked);
 353                        try d.tossBitsShort(sym.code_bits);
 354                        pos += try d.dynamicCodeLength(sym.symbol, &dec_lens, pos);
 355                    }
 356                    if (pos > hlit + hdist) {
 357                        return error.InvalidDynamicBlockHeader;
 358                    }
 359
 360                    // literal code lengths to literal decoder
 361                    try d.lit_dec.generate(dec_lens[0..hlit]);
 362
 363                    // distance code lengths to distance decoder
 364                    try d.dst_dec.generate(dec_lens[hlit..][0..hdist]);
 365
 366                    continue :sw .dynamic_block;
 367                },
 368                .invalid => return error.InvalidBlockType,
 369            }
 370        },
 371        .stored_block => |remaining_len| {
 372            const out: []u8 = if (remaining != 0)
 373                try w.writableSliceGreedyPreserve(flate.history_len, 1)
 374            else
 375                &.{};
 376            var limited_out: [1][]u8 = .{limit.min(.limited(remaining_len)).slice(out)};
 377            const n = try in.readVec(&limited_out);
 378            if (remaining_len - n == 0) {
 379                d.state = if (d.final_block) .protocol_footer else .block_header;
 380            } else {
 381                d.state = .{ .stored_block = @intCast(remaining_len - n) };
 382            }
 383            w.advance(n);
 384            return @intFromEnum(limit) - remaining + n;
 385        },
 386        .fixed_block => {
 387            while (remaining > 0) {
 388                const code = try d.readFixedCode();
 389                switch (code) {
 390                    0...255 => {
 391                        if (remaining != 0) {
 392                            @branchHint(.likely);
 393                            try w.writeBytePreserve(flate.history_len, @intCast(code));
 394                            remaining -= 1;
 395                        } else {
 396                            d.state = .{ .fixed_block_literal = @intCast(code) };
 397                            return @intFromEnum(limit) - remaining;
 398                        }
 399                    },
 400                    256 => {
 401                        d.state = if (d.final_block) .protocol_footer else .block_header;
 402                        return @intFromEnum(limit) - remaining;
 403                    },
 404                    257...285 => {
 405                        // Handles fixed block non literal (length) code.
 406                        // Length code is followed by 5 bits of distance code.
 407                        const length = try d.decodeLength(@intCast(code - 257));
 408                        continue :sw .{ .fixed_block_match = length };
 409                    },
 410                    else => return error.InvalidCode,
 411                }
 412            }
 413            d.state = .fixed_block;
 414            return @intFromEnum(limit) - remaining;
 415        },
 416        .fixed_block_literal => |symbol| {
 417            assert(remaining != 0);
 418            remaining -= 1;
 419            try w.writeBytePreserve(flate.history_len, symbol);
 420            continue :sw .fixed_block;
 421        },
 422        .fixed_block_match => |length| {
 423            if (remaining >= length) {
 424                @branchHint(.likely);
 425                const distance = try d.decodeDistance(@bitReverse(try d.takeIntBits(u5)));
 426                try writeMatch(w, length, distance);
 427                remaining -= length;
 428                continue :sw .fixed_block;
 429            } else {
 430                d.state = .{ .fixed_block_match = length };
 431                return @intFromEnum(limit) - remaining;
 432            }
 433        },
 434        .dynamic_block => {
 435            // In larger archives most blocks are usually dynamic, so
 436            // decompression performance depends on this logic.
 437            var sym = try d.decodeSymbol(&d.lit_dec);
 438            sym: switch (sym.kind) {
 439                .literal => {
 440                    if (remaining != 0) {
 441                        @branchHint(.likely);
 442                        remaining -= 1;
 443                        try w.writeBytePreserve(flate.history_len, sym.symbol);
 444                        sym = try d.decodeSymbol(&d.lit_dec);
 445                        continue :sym sym.kind;
 446                    } else {
 447                        d.state = .{ .dynamic_block_literal = sym.symbol };
 448                        return @intFromEnum(limit) - remaining;
 449                    }
 450                },
 451                .match => {
 452                    // Decode match backreference <length, distance>
 453                    const length = try d.decodeLength(@intCast(sym.symbol));
 454                    continue :sw .{ .dynamic_block_match = length };
 455                },
 456                .end_of_block => {
 457                    d.state = if (d.final_block) .protocol_footer else .block_header;
 458                    continue :sw d.state;
 459                },
 460            }
 461        },
 462        .dynamic_block_literal => |symbol| {
 463            assert(remaining != 0);
 464            remaining -= 1;
 465            try w.writeBytePreserve(flate.history_len, symbol);
 466            continue :sw .dynamic_block;
 467        },
 468        .dynamic_block_match => |length| {
 469            if (remaining >= length) {
 470                @branchHint(.likely);
 471                remaining -= length;
 472                const dsm = try d.decodeSymbol(&d.dst_dec);
 473                const distance = try d.decodeDistance(@intCast(dsm.symbol));
 474                try writeMatch(w, length, distance);
 475                continue :sw .dynamic_block;
 476            } else {
 477                d.state = .{ .dynamic_block_match = length };
 478                return @intFromEnum(limit) - remaining;
 479            }
 480        },
 481        .protocol_footer => {
 482            d.alignBitsForward();
 483            switch (d.container_metadata) {
 484                .gzip => |*gzip| {
 485                    gzip.crc = try in.takeInt(u32, .little);
 486                    gzip.count = try in.takeInt(u32, .little);
 487                },
 488                .zlib => |*zlib| {
 489                    zlib.adler = try in.takeInt(u32, .big);
 490                },
 491                .raw => {},
 492            }
 493            d.state = .end;
 494            return @intFromEnum(limit) - remaining;
 495        },
 496        .end => return error.EndOfStream,
 497    }
 498}
 499
 500/// Write match (back-reference to the same data slice) starting at `distance`
 501/// back from current write position, and `length` of bytes.
 502fn writeMatch(w: *Writer, length: u16, distance: u16) !void {
 503    if (w.end < distance) return error.InvalidMatch;
 504    if (length < token.min_length) return error.InvalidMatch;
 505    if (length > token.max_length) return error.InvalidMatch;
 506    if (distance < token.min_distance) return error.InvalidMatch;
 507    if (distance > token.max_distance) return error.InvalidMatch;
 508
 509    // This is not a @memmove; it intentionally repeats patterns caused by
 510    // iterating one byte at a time.
 511    const dest = try w.writableSlicePreserve(flate.history_len, length);
 512    const end = dest.ptr - w.buffer.ptr;
 513    const src = w.buffer[end - distance ..][0..length];
 514    for (dest, src) |*d, s| d.* = s;
 515}
 516
 517fn peekBits(d: *Decompress, n: u4) !u16 {
 518    const bits = d.input.peekInt(u32, .little) catch |e| return switch (e) {
 519        error.ReadFailed => error.ReadFailed,
 520        error.EndOfStream => d.peekBitsEnding(n),
 521    };
 522    const mask = @shlExact(@as(u16, 1), n) - 1;
 523    return @intCast((bits >> d.consumed_bits) & mask);
 524}
 525
 526fn peekBitsEnding(d: *Decompress, n: u4) !u16 {
 527    @branchHint(.unlikely);
 528
 529    const left = d.input.buffered();
 530    if (left.len * 8 - d.consumed_bits < n) return error.EndOfStream;
 531    const bits = std.mem.readVarInt(u32, left, .little);
 532    const mask = @shlExact(@as(u16, 1), n) - 1;
 533    return @intCast((bits >> d.consumed_bits) & mask);
 534}
 535
 536/// Safe only after `peekBits` has been called with a greater or equal `n` value.
 537fn tossBits(d: *Decompress, n: u4) void {
 538    d.input.toss((@as(u8, n) + d.consumed_bits) / 8);
 539    d.consumed_bits +%= @truncate(n);
 540}
 541
 542fn takeBits(d: *Decompress, n: u4) !u16 {
 543    const bits = try d.peekBits(n);
 544    d.tossBits(n);
 545    return bits;
 546}
 547
 548fn alignBitsForward(d: *Decompress) void {
 549    d.input.toss(@intFromBool(d.consumed_bits != 0));
 550    d.consumed_bits = 0;
 551}
 552
 553fn peekBitsShort(d: *Decompress, n: u4) !u16 {
 554    const bits = d.input.peekInt(u32, .little) catch |e| return switch (e) {
 555        error.ReadFailed => error.ReadFailed,
 556        error.EndOfStream => d.peekBitsShortEnding(n),
 557    };
 558    const mask = @shlExact(@as(u16, 1), n) - 1;
 559    return @intCast((bits >> d.consumed_bits) & mask);
 560}
 561
 562fn peekBitsShortEnding(d: *Decompress, n: u4) !u16 {
 563    @branchHint(.unlikely);
 564
 565    const left = d.input.buffered();
 566    const bits = std.mem.readVarInt(u32, left, .little);
 567    const mask = @shlExact(@as(u16, 1), n) - 1;
 568    return @intCast((bits >> d.consumed_bits) & mask);
 569}
 570
 571fn tossBitsShort(d: *Decompress, n: u4) !void {
 572    if (d.input.bufferedLen() * 8 + d.consumed_bits < n) return error.EndOfStream;
 573    d.tossBits(n);
 574}
 575
 576fn takeIntBits(d: *Decompress, T: type) !T {
 577    return @intCast(try d.takeBits(@bitSizeOf(T)));
 578}
 579
 580fn peekIntBitsShort(d: *Decompress, T: type) !T {
 581    return @intCast(try d.peekBitsShort(@bitSizeOf(T)));
 582}
 583
 584/// Reads first 7 bits, and then maybe 1 or 2 more to get full 7,8 or 9 bit code.
 585/// ref: https://datatracker.ietf.org/doc/html/rfc1951#page-12
 586///         Lit Value    Bits        Codes
 587///          ---------    ----        -----
 588///            0 - 143     8          00110000 through
 589///                                   10111111
 590///          144 - 255     9          110010000 through
 591///                                   111111111
 592///          256 - 279     7          0000000 through
 593///                                   0010111
 594///          280 - 287     8          11000000 through
 595///                                   11000111
 596fn readFixedCode(d: *Decompress) !u16 {
 597    const code7 = @bitReverse(try d.takeIntBits(u7));
 598    return switch (code7) {
 599        0...0b0010_111 => @as(u16, code7) + 256,
 600        0b0010_111 + 1...0b1011_111 => (@as(u16, code7) << 1) + @as(u16, try d.takeIntBits(u1)) - 0b0011_0000,
 601        0b1011_111 + 1...0b1100_011 => (@as(u16, code7 - 0b1100000) << 1) + try d.takeIntBits(u1) + 280,
 602        else => (@as(u16, code7 - 0b1100_100) << 2) + @as(u16, @bitReverse(try d.takeIntBits(u2))) + 144,
 603    };
 604}
 605
 606pub const Symbol = packed struct {
 607    pub const Kind = enum(u2) {
 608        literal,
 609        end_of_block,
 610        match,
 611    };
 612
 613    symbol: u8 = 0, // symbol from alphabet
 614    code_bits: u4 = 0, // number of bits in code 0-15
 615    kind: Kind = .literal,
 616
 617    code: u16 = 0, // huffman code of the symbol
 618    next: u16 = 0, // pointer to the next symbol in linked list
 619    // it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup
 620
 621    // Sorting less than function.
 622    pub fn asc(_: void, a: Symbol, b: Symbol) bool {
 623        if (a.code_bits == b.code_bits) {
 624            if (a.kind == b.kind) {
 625                return a.symbol < b.symbol;
 626            }
 627            return @intFromEnum(a.kind) < @intFromEnum(b.kind);
 628        }
 629        return a.code_bits < b.code_bits;
 630    }
 631};
 632
 633pub const LiteralDecoder = HuffmanDecoder(286, 15, 9);
 634pub const DistanceDecoder = HuffmanDecoder(30, 15, 9);
 635pub const CodegenDecoder = HuffmanDecoder(19, 7, 7);
 636
 637/// Creates huffman tree codes from list of code lengths (in `build`).
 638///
 639/// `find` then finds symbol for code bits. Code can be any length between 1 and
 640/// 15 bits. When calling `find` we don't know how many bits will be used to
 641/// find symbol. When symbol is returned it has code_bits field which defines
 642/// how much we should advance in bit stream.
 643///
 644/// Lookup table is used to map 15 bit int to symbol. Same symbol is written
 645/// many times in this table; 32K places for 286 (at most) symbols.
 646/// Small lookup table is optimization for faster search.
 647/// It is variation of the algorithm explained in [zlib](https://github.com/madler/zlib/blob/643e17b7498d12ab8d15565662880579692f769d/doc/algorithm.txt#L92)
 648/// with difference that we here use statically allocated arrays.
 649///
 650fn HuffmanDecoder(
 651    comptime alphabet_size: u16,
 652    comptime max_code_bits: u4,
 653    comptime lookup_bits: u4,
 654) type {
 655    const lookup_shift = max_code_bits - lookup_bits;
 656
 657    return struct {
 658        // all symbols in alaphabet, sorted by code_len, symbol
 659        symbols: [alphabet_size]Symbol = undefined,
 660        // lookup table code -> symbol
 661        lookup: [1 << lookup_bits]Symbol = undefined,
 662
 663        const Self = @This();
 664
 665        /// Generates symbols and lookup tables from list of code lens for each symbol.
 666        pub fn generate(self: *Self, lens: []const u4) !void {
 667            try checkCompleteness(lens);
 668
 669            // init alphabet with code_bits
 670            for (self.symbols, 0..) |_, i| {
 671                const cb: u4 = if (i < lens.len) lens[i] else 0;
 672                self.symbols[i] = if (i < 256)
 673                    .{ .kind = .literal, .symbol = @intCast(i), .code_bits = cb }
 674                else if (i == 256)
 675                    .{ .kind = .end_of_block, .symbol = 0xff, .code_bits = cb }
 676                else
 677                    .{ .kind = .match, .symbol = @intCast(i - 257), .code_bits = cb };
 678            }
 679            std.sort.heap(Symbol, &self.symbols, {}, Symbol.asc);
 680
 681            // reset lookup table
 682            for (0..self.lookup.len) |i| {
 683                self.lookup[i] = .{};
 684            }
 685
 686            // assign code to symbols
 687            // reference: https://youtu.be/9_YEGLe33NA?list=PLU4IQLU9e_OrY8oASHx0u3IXAL9TOdidm&t=2639
 688            var code: u16 = 0;
 689            var idx: u16 = 0;
 690            for (&self.symbols, 0..) |*sym, pos| {
 691                if (sym.code_bits == 0) continue; // skip unused
 692                sym.code = code;
 693
 694                const next_code = code + (@as(u16, 1) << (max_code_bits - sym.code_bits));
 695                const next_idx = next_code >> lookup_shift;
 696
 697                if (next_idx > self.lookup.len or idx >= self.lookup.len) break;
 698                if (sym.code_bits <= lookup_bits) {
 699                    // fill small lookup table
 700                    for (idx..next_idx) |j|
 701                        self.lookup[j] = sym.*;
 702                } else {
 703                    // insert into linked table starting at root
 704                    const root = &self.lookup[idx];
 705                    const root_next = root.next;
 706                    root.next = @intCast(pos);
 707                    sym.next = root_next;
 708                }
 709
 710                idx = next_idx;
 711                code = next_code;
 712            }
 713        }
 714
 715        /// Given the list of code lengths check that it represents a canonical
 716        /// Huffman code for n symbols.
 717        ///
 718        /// Reference: https://github.com/madler/zlib/blob/5c42a230b7b468dff011f444161c0145b5efae59/contrib/puff/puff.c#L340
 719        fn checkCompleteness(lens: []const u4) !void {
 720            if (alphabet_size == 286)
 721                if (lens[256] == 0) return error.MissingEndOfBlockCode;
 722
 723            var count = [_]u16{0} ** (@as(usize, max_code_bits) + 1);
 724            var max: usize = 0;
 725            for (lens) |n| {
 726                if (n == 0) continue;
 727                if (n > max) max = n;
 728                count[n] += 1;
 729            }
 730            if (max == 0) // empty tree
 731                return;
 732
 733            // check for an over-subscribed or incomplete set of lengths
 734            var left: usize = 1; // one possible code of zero length
 735            for (1..count.len) |len| {
 736                left <<= 1; // one more bit, double codes left
 737                if (count[len] > left)
 738                    return error.OversubscribedHuffmanTree;
 739                left -= count[len]; // deduct count from possible codes
 740            }
 741            if (left > 0) { // left > 0 means incomplete
 742                // incomplete code ok only for single length 1 code
 743                if (max_code_bits > 7 and max == count[0] + count[1]) return;
 744                return error.IncompleteHuffmanTree;
 745            }
 746        }
 747
 748        /// Finds symbol for lookup table code.
 749        pub fn find(self: *Self, code: u16) !Symbol {
 750            // try to find in lookup table
 751            const idx = code >> lookup_shift;
 752            const sym = self.lookup[idx];
 753            if (sym.code_bits != 0) return sym;
 754            // if not use linked list of symbols with same prefix
 755            return self.findLinked(code, sym.next);
 756        }
 757
 758        fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
 759            var pos = start;
 760            while (pos > 0) {
 761                const sym = self.symbols[pos];
 762                const shift = max_code_bits - sym.code_bits;
 763                // compare code_bits number of upper bits
 764                if ((code ^ sym.code) >> shift == 0) return sym;
 765                pos = sym.next;
 766            }
 767            return error.InvalidCode;
 768        }
 769    };
 770}
 771
 772test "init/find" {
 773    // example data from: https://youtu.be/SJPvNi4HrWQ?t=8423
 774    const code_lens = [_]u4{ 4, 3, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 2 };
 775    var h: CodegenDecoder = .{};
 776    try h.generate(&code_lens);
 777
 778    const expected = [_]struct {
 779        sym: Symbol,
 780        code: u16,
 781    }{
 782        .{
 783            .code = 0b00_00000,
 784            .sym = .{ .symbol = 3, .code_bits = 2 },
 785        },
 786        .{
 787            .code = 0b01_00000,
 788            .sym = .{ .symbol = 18, .code_bits = 2 },
 789        },
 790        .{
 791            .code = 0b100_0000,
 792            .sym = .{ .symbol = 1, .code_bits = 3 },
 793        },
 794        .{
 795            .code = 0b101_0000,
 796            .sym = .{ .symbol = 4, .code_bits = 3 },
 797        },
 798        .{
 799            .code = 0b110_0000,
 800            .sym = .{ .symbol = 17, .code_bits = 3 },
 801        },
 802        .{
 803            .code = 0b1110_000,
 804            .sym = .{ .symbol = 0, .code_bits = 4 },
 805        },
 806        .{
 807            .code = 0b1111_000,
 808            .sym = .{ .symbol = 16, .code_bits = 4 },
 809        },
 810    };
 811
 812    // unused symbols
 813    for (0..12) |i| {
 814        try testing.expectEqual(0, h.symbols[i].code_bits);
 815    }
 816    // used, from index 12
 817    for (expected, 12..) |e, i| {
 818        try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol);
 819        try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits);
 820        const sym_from_code = try h.find(e.code);
 821        try testing.expectEqual(e.sym.symbol, sym_from_code.symbol);
 822    }
 823
 824    // All possible codes for each symbol.
 825    // Lookup table has 126 elements, to cover all possible 7 bit codes.
 826    for (0b0000_000..0b0100_000) |c| // 0..32 (32)
 827        try testing.expectEqual(3, (try h.find(@intCast(c))).symbol);
 828
 829    for (0b0100_000..0b1000_000) |c| // 32..64 (32)
 830        try testing.expectEqual(18, (try h.find(@intCast(c))).symbol);
 831
 832    for (0b1000_000..0b1010_000) |c| // 64..80 (16)
 833        try testing.expectEqual(1, (try h.find(@intCast(c))).symbol);
 834
 835    for (0b1010_000..0b1100_000) |c| // 80..96 (16)
 836        try testing.expectEqual(4, (try h.find(@intCast(c))).symbol);
 837
 838    for (0b1100_000..0b1110_000) |c| // 96..112 (16)
 839        try testing.expectEqual(17, (try h.find(@intCast(c))).symbol);
 840
 841    for (0b1110_000..0b1111_000) |c| // 112..120 (8)
 842        try testing.expectEqual(0, (try h.find(@intCast(c))).symbol);
 843
 844    for (0b1111_000..0b1_0000_000) |c| // 120...128 (8)
 845        try testing.expectEqual(16, (try h.find(@intCast(c))).symbol);
 846}
 847
 848test "encode/decode literals" {
 849    // Check that the example in RFC 1951 section 3.2.2 works (plus some zeroes)
 850    const max_bits = 5;
 851    var decoder: HuffmanDecoder(16, max_bits, 3) = .{};
 852    try decoder.generate(&.{ 3, 3, 3, 3, 0, 0, 3, 2, 4, 4 });
 853
 854    inline for (0.., .{
 855        @as(u3, 0b010),
 856        @as(u3, 0b011),
 857        @as(u3, 0b100),
 858        @as(u3, 0b101),
 859        @as(u0, 0),
 860        @as(u0, 0),
 861        @as(u3, 0b110),
 862        @as(u2, 0b00),
 863        @as(u4, 0b1110),
 864        @as(u4, 0b1111),
 865    }) |i, code| {
 866        const bits = @bitSizeOf(@TypeOf(code));
 867        if (bits == 0) continue;
 868        for (0..1 << (max_bits - bits)) |extra| {
 869            const full = (@as(u16, code) << (max_bits - bits)) | @as(u16, @intCast(extra));
 870            const symbol = try decoder.find(full);
 871            try testing.expectEqual(i, symbol.symbol);
 872            try testing.expectEqual(bits, symbol.code_bits);
 873        }
 874    }
 875}
 876
 877test "non compressed block (type 0)" {
 878    try testDecompress(.raw, &[_]u8{
 879        0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
 880        'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
 881    }, "Hello world\n");
 882}
 883
 884test "fixed code block (type 1)" {
 885    try testDecompress(.raw, &[_]u8{
 886        0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
 887        0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
 888    }, "Hello world\n");
 889}
 890
 891test "dynamic block (type 2)" {
 892    try testDecompress(.raw, &[_]u8{
 893        0x3d, 0xc6, 0x39, 0x11, 0x00, 0x00, 0x0c, 0x02, // deflate data block type 2
 894        0x30, 0x2b, 0xb5, 0x52, 0x1e, 0xff, 0x96, 0x38,
 895        0x16, 0x96, 0x5c, 0x1e, 0x94, 0xcb, 0x6d, 0x01,
 896    }, "ABCDEABCD ABCDEABCD");
 897}
 898
 899test "gzip non compressed block (type 0)" {
 900    try testDecompress(.gzip, &[_]u8{
 901        0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, // gzip header (10 bytes)
 902        0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
 903        'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
 904        0xd5, 0xe0, 0x39, 0xb7, // gzip footer: checksum
 905        0x0c, 0x00, 0x00, 0x00, // gzip footer: size
 906    }, "Hello world\n");
 907}
 908
 909test "gzip fixed code block (type 1)" {
 910    try testDecompress(.gzip, &[_]u8{
 911        0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x03, // gzip header (10 bytes)
 912        0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, // deflate data block type 1
 913        0x2f, 0xca, 0x49, 0xe1, 0x02, 0x00,
 914        0xd5, 0xe0, 0x39, 0xb7, 0x0c, 0x00, 0x00, 0x00, // gzip footer (chksum, len)
 915    }, "Hello world\n");
 916}
 917
 918test "gzip dynamic block (type 2)" {
 919    try testDecompress(.gzip, &[_]u8{
 920        0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, // gzip header (10 bytes)
 921        0x3d, 0xc6, 0x39, 0x11, 0x00, 0x00, 0x0c, 0x02, // deflate data block type 2
 922        0x30, 0x2b, 0xb5, 0x52, 0x1e, 0xff, 0x96, 0x38,
 923        0x16, 0x96, 0x5c, 0x1e, 0x94, 0xcb, 0x6d, 0x01,
 924        0x17, 0x1c, 0x39, 0xb4, 0x13, 0x00, 0x00, 0x00, // gzip footer (chksum, len)
 925    }, "ABCDEABCD ABCDEABCD");
 926}
 927
 928test "gzip header with name" {
 929    try testDecompress(.gzip, &[_]u8{
 930        0x1f, 0x8b, 0x08, 0x08, 0xe5, 0x70, 0xb1, 0x65, 0x00, 0x03, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2e,
 931        0x74, 0x78, 0x74, 0x00, 0xf3, 0x48, 0xcd, 0xc9, 0xc9, 0x57, 0x28, 0xcf, 0x2f, 0xca, 0x49, 0xe1,
 932        0x02, 0x00, 0xd5, 0xe0, 0x39, 0xb7, 0x0c, 0x00, 0x00, 0x00,
 933    }, "Hello world\n");
 934}
 935
 936test "zlib decompress non compressed block (type 0)" {
 937    try testDecompress(.zlib, &[_]u8{
 938        0x78, 0b10_0_11100, // zlib header (2 bytes)
 939        0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // deflate fixed buffer header len, nlen
 940        'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
 941        0x1c, 0xf2, 0x04, 0x47, // zlib footer: checksum
 942    }, "Hello world\n");
 943}
 944
 945test "failing end-of-stream" {
 946    try testFailure(.raw, @embedFile("testdata/fuzz/end-of-stream.input"), error.EndOfStream);
 947}
 948test "failing invalid-distance" {
 949    try testFailure(.raw, @embedFile("testdata/fuzz/invalid-distance.input"), error.InvalidMatch);
 950}
 951test "failing invalid-tree01" {
 952    try testFailure(.raw, @embedFile("testdata/fuzz/invalid-tree01.input"), error.IncompleteHuffmanTree);
 953}
 954test "failing invalid-tree02" {
 955    try testFailure(.raw, @embedFile("testdata/fuzz/invalid-tree02.input"), error.IncompleteHuffmanTree);
 956}
 957test "failing invalid-tree03" {
 958    try testFailure(.raw, @embedFile("testdata/fuzz/invalid-tree03.input"), error.IncompleteHuffmanTree);
 959}
 960test "failing lengths-overflow" {
 961    try testFailure(.raw, @embedFile("testdata/fuzz/lengths-overflow.input"), error.InvalidDynamicBlockHeader);
 962}
 963test "failing out-of-codes" {
 964    try testFailure(.raw, @embedFile("testdata/fuzz/out-of-codes.input"), error.InvalidCode);
 965}
 966test "failing puff01" {
 967    try testFailure(.raw, @embedFile("testdata/fuzz/puff01.input"), error.WrongStoredBlockNlen);
 968}
 969test "failing puff02" {
 970    try testFailure(.raw, @embedFile("testdata/fuzz/puff02.input"), error.EndOfStream);
 971}
 972test "failing puff04" {
 973    try testFailure(.raw, @embedFile("testdata/fuzz/puff04.input"), error.InvalidCode);
 974}
 975test "failing puff05" {
 976    try testFailure(.raw, @embedFile("testdata/fuzz/puff05.input"), error.EndOfStream);
 977}
 978test "failing puff06" {
 979    try testFailure(.raw, @embedFile("testdata/fuzz/puff06.input"), error.EndOfStream);
 980}
 981test "failing puff08" {
 982    try testFailure(.raw, @embedFile("testdata/fuzz/puff08.input"), error.InvalidCode);
 983}
 984test "failing puff10" {
 985    try testFailure(.raw, @embedFile("testdata/fuzz/puff10.input"), error.InvalidCode);
 986}
 987test "failing puff11" {
 988    try testFailure(.raw, @embedFile("testdata/fuzz/puff11.input"), error.InvalidMatch);
 989}
 990test "failing puff12" {
 991    try testFailure(.raw, @embedFile("testdata/fuzz/puff12.input"), error.InvalidDynamicBlockHeader);
 992}
 993test "failing puff13" {
 994    try testFailure(.raw, @embedFile("testdata/fuzz/puff13.input"), error.IncompleteHuffmanTree);
 995}
 996test "failing puff14" {
 997    try testFailure(.raw, @embedFile("testdata/fuzz/puff14.input"), error.EndOfStream);
 998}
 999test "failing puff15" {
1000    try testFailure(.raw, @embedFile("testdata/fuzz/puff15.input"), error.IncompleteHuffmanTree);
1001}
1002test "failing puff16" {
1003    try testFailure(.raw, @embedFile("testdata/fuzz/puff16.input"), error.InvalidDynamicBlockHeader);
1004}
1005test "failing puff17" {
1006    try testFailure(.raw, @embedFile("testdata/fuzz/puff17.input"), error.MissingEndOfBlockCode);
1007}
1008test "failing fuzz1" {
1009    try testFailure(.raw, @embedFile("testdata/fuzz/fuzz1.input"), error.InvalidDynamicBlockHeader);
1010}
1011test "failing fuzz2" {
1012    try testFailure(.raw, @embedFile("testdata/fuzz/fuzz2.input"), error.InvalidDynamicBlockHeader);
1013}
1014test "failing fuzz3" {
1015    try testFailure(.raw, @embedFile("testdata/fuzz/fuzz3.input"), error.InvalidMatch);
1016}
1017test "failing fuzz4" {
1018    try testFailure(.raw, @embedFile("testdata/fuzz/fuzz4.input"), error.OversubscribedHuffmanTree);
1019}
1020test "failing puff18" {
1021    try testFailure(.raw, @embedFile("testdata/fuzz/puff18.input"), error.OversubscribedHuffmanTree);
1022}
1023test "failing puff19" {
1024    try testFailure(.raw, @embedFile("testdata/fuzz/puff19.input"), error.OversubscribedHuffmanTree);
1025}
1026test "failing puff20" {
1027    try testFailure(.raw, @embedFile("testdata/fuzz/puff20.input"), error.OversubscribedHuffmanTree);
1028}
1029test "failing puff21" {
1030    try testFailure(.raw, @embedFile("testdata/fuzz/puff21.input"), error.OversubscribedHuffmanTree);
1031}
1032test "failing puff22" {
1033    try testFailure(.raw, @embedFile("testdata/fuzz/puff22.input"), error.OversubscribedHuffmanTree);
1034}
1035test "failing puff23" {
1036    try testFailure(.raw, @embedFile("testdata/fuzz/puff23.input"), error.OversubscribedHuffmanTree);
1037}
1038test "failing puff24" {
1039    try testFailure(.raw, @embedFile("testdata/fuzz/puff24.input"), error.IncompleteHuffmanTree);
1040}
1041test "failing puff25" {
1042    try testFailure(.raw, @embedFile("testdata/fuzz/puff25.input"), error.OversubscribedHuffmanTree);
1043}
1044test "failing puff26" {
1045    try testFailure(.raw, @embedFile("testdata/fuzz/puff26.input"), error.InvalidDynamicBlockHeader);
1046}
1047test "failing puff27" {
1048    try testFailure(.raw, @embedFile("testdata/fuzz/puff27.input"), error.InvalidDynamicBlockHeader);
1049}
1050
test "deflate-stream" {
    const input = @embedFile("testdata/fuzz/deflate-stream.input");
    const expected = @embedFile("testdata/fuzz/deflate-stream.expect");
    try testDecompress(.raw, input, expected);
}

test "empty-distance-alphabet01" {
    // Valid stream that decodes to no output at all.
    const input = @embedFile("testdata/fuzz/empty-distance-alphabet01.input");
    try testDecompress(.raw, input, "");
}

test "empty-distance-alphabet02" {
    // Valid stream that decodes to no output at all.
    const input = @embedFile("testdata/fuzz/empty-distance-alphabet02.input");
    try testDecompress(.raw, input, "");
}

test "puff03" {
    const input = @embedFile("testdata/fuzz/puff03.input");
    const expected: []const u8 = &.{0xa};
    try testDecompress(.raw, input, expected);
}

test "puff09" {
    const input = @embedFile("testdata/fuzz/puff09.input");
    try testDecompress(.raw, input, "P");
}
1074
test "invalid block type" {
    // Bit 0 (BFINAL) = 0, bits 1-2 (BTYPE) = 0b11, which is not a valid block type.
    const input = [_]u8{0b110};
    try testFailure(.raw, &input, error.InvalidBlockType);
}
1078
test "bug 18966" {
    // Regression input for a gzip stream; must decode to the recorded output.
    const input = @embedFile("testdata/fuzz/bug_18966.input");
    const expected = @embedFile("testdata/fuzz/bug_18966.expect");
    try testDecompress(.gzip, input, expected);
}
1086
test "reading into empty buffer" {
    // Inspired by https://github.com/ziglang/zig/issues/19895
    const input = &[_]u8{
        0b0000_0001, 0b0000_1100, 0x00, 0b1111_0011, 0xff, // stored block header: BFINAL=1 BTYPE=00, LEN=12, NLEN=~LEN
        'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0a, // non compressed data
    };
    var in: Reader = .fixed(input);
    var decomp: Decompress = .init(&in, .raw, &.{});
    const r = &decomp.reader;
    // A single zero-length destination must yield 0 bytes rather than an error.
    var bufs: [1][]u8 = .{&.{}};
    try testing.expectEqual(0, try r.readVec(&bufs));
}
1099
test "zlib header" {
    const cases = [_]struct { bytes: []const u8, err: anyerror }{
        // Only the CMF byte is present; FLG is missing.
        .{ .bytes = &.{0x78}, .err = error.EndOfStream },
        // CMF 0x79: compression method 9 instead of deflate (8).
        .{ .bytes = &.{ 0x79, 0x94 }, .err = error.BadZlibHeader },
        // CMF 0x88: window size exponent (CINFO) 8 exceeds the maximum of 7.
        .{ .bytes = &.{ 0x88, 0x98 }, .err = error.BadZlibHeader },
        // Valid header and empty final fixed block, then only 1 of the 4 Adler-32 bytes.
        .{ .bytes = &.{ 0x78, 0xda, 0x03, 0x00, 0x00 }, .err = error.EndOfStream },
    };
    for (cases) |case| try testFailure(.zlib, case.bytes, case.err);
}
1113
test "gzip header" {
    const failures = [_]struct { bytes: []const u8, err: anyerror }{
        // Header cut off after the two magic bytes.
        .{ .bytes = &.{ 0x1f, 0x8b }, .err = error.EndOfStream },
        // CM = 9 instead of deflate (8).
        .{
            .bytes = &.{
                0x1f, 0x8b, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x03,
            },
            .err = error.BadGzipHeader,
        },
        // Empty final fixed block, then only 3 of the 4 CRC-32 bytes.
        .{
            .bytes = &.{
                0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00,
            },
            .err = error.EndOfStream,
        },
        // Full CRC-32, then only 3 of the 4 ISIZE bytes.
        .{
            .bytes = &.{
                0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
                0x00, 0x00, 0x00,
            },
            .err = error.EndOfStream,
        },
    };
    for (failures) |failure| try testFailure(.gzip, failure.bytes, failure.err);

    try testDecompress(.gzip, &[_]u8{
        // ID1 ID2 CM FLG(FHCRC | FCOMMENT) MTIME(4) XFL OS
        0x1f, 0x8b, 0x08, 0x12, 0x00, 0x09, 0x6e, 0x88, 0x00, 0xff,
        // FCOMMENT: "Hello", zero terminated
        0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x00,
        // FHCRC (should cover entire header)
        0x99, 0xd6,
        // deflate: final stored block with LEN = 0, NLEN = 0xffff
        0x01, 0x00, 0x00, 0xff, 0xff,
        // gzip footer: CRC-32 and ISIZE, both zero
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    }, "");
}
1146
test "zlib should not overshoot" {
    // Compressed zlib data with extra 4 bytes at the end.
    const data = [_]u8{
        0x78, 0x9c, 0x73, 0xce, 0x2f, 0xa8, 0x2c, 0xca, 0x4c, 0xcf, 0x28, 0x51, 0x08, 0xcf, 0xcc, 0xc9,
        0x49, 0xcd, 0x55, 0x28, 0x4b, 0xcc, 0x53, 0x08, 0x4e, 0xce, 0x48, 0xcc, 0xcc, 0xd6, 0x51, 0x08,
        0xce, 0xcc, 0x4b, 0x4f, 0x2c, 0xc8, 0x2f, 0x4a, 0x55, 0x30, 0xb4, 0xb4, 0x34, 0xd5, 0xb5, 0x34,
        0x03, 0x00, 0x8b, 0x61, 0x0f, 0xa4, 0x52, 0x5a, 0x94, 0x12,
    };

    var reader: Reader = .fixed(&data);

    var decompress_buffer: [flate.max_window_len]u8 = undefined;
    var decompress: Decompress = .init(&reader, .zlib, &decompress_buffer);
    var out: [128]u8 = undefined;

    {
        const n = try decompress.reader.readSliceShort(&out);
        try testing.expectEqual(46, n);
        try testing.expectEqualStrings("Copyright Willem van Schaik, Singapore 1995-96", out[0..n]);
    }

    // 4 bytes after compressed chunk are available in reader: the
    // decompressor must not consume input past the zlib trailer.
    const n = try reader.readSliceShort(&out);
    try testing.expectEqual(4, n);
    try testing.expectEqualSlices(u8, data[data.len - 4 ..], out[0..n]);
}
1173
/// Decompresses `in` as `container` data and asserts that streaming fails.
/// The stream itself reports only `error.ReadFailed`; the specific decoder
/// error is read back from `decompress.err` and compared with `expected_err`.
/// Fails with `error.TestFailed` if no error was recorded.
fn testFailure(container: Container, in: []const u8, expected_err: anyerror) !void {
    var input: Reader = .fixed(in);
    var out: Writer.Allocating = .init(testing.allocator);
    defer out.deinit();

    var decompress: Decompress = .init(&input, container, &.{});
    const streamed = decompress.reader.streamRemaining(&out.writer);
    try testing.expectError(error.ReadFailed, streamed);

    const actual_err = decompress.err orelse return error.TestFailed;
    try testing.expectEqual(expected_err, actual_err);
}
1183
/// Decompresses `compressed` as `container` data into an allocating writer
/// and asserts that both the reported length and the produced bytes equal
/// `expected_plain`.
fn testDecompress(container: Container, compressed: []const u8, expected_plain: []const u8) !void {
    var in: Reader = .fixed(compressed);
    var out: Writer.Allocating = .init(testing.allocator);
    defer out.deinit();

    var decompress: Decompress = .init(&in, container, &.{});
    const n = try decompress.reader.streamRemaining(&out.writer);
    try testing.expectEqual(expected_plain.len, n);
    try testing.expectEqualSlices(u8, expected_plain, out.written());
}